diff --git a/invenio_utilities_tuw/cli/drafts.py b/invenio_utilities_tuw/cli/drafts.py index f43661b3a7a42a42812b129e17bf08dbab93bc6d..13b09358751d4b1c8304273bda473daecc4f3699 100644 --- a/invenio_utilities_tuw/cli/drafts.py +++ b/invenio_utilities_tuw/cli/drafts.py @@ -21,6 +21,7 @@ from invenio_access.permissions import system_identity from invenio_db import db from invenio_rdm_records.proxies import current_rdm_records_service as service from invenio_records_resources.services.errors import PermissionDeniedError +from invenio_records_resources.services.uow import UnitOfWork from ..utils import get_identity_for_user, get_user_by_identifier from .options import ( @@ -33,6 +34,8 @@ from .options import ( option_vanity_pid, ) from .utils import ( + auto_increase_bucket_limits, + collect_file_paths, convert_to_recid, create_record_from_metadata, patch_metadata, @@ -41,41 +44,6 @@ from .utils import ( ) -def file_exists(path, filename): - """Check if the file exists in the given path.""" - return isfile(join(path, filename)) - - -def auto_increase_bucket_limits(bucket, filepaths, to_unlimited=False): - """Dynamically increase the bucket quoat if necessary.""" - # see what the file sizes look like - file_sizes = [os.path.getsize(filepath) for filepath in filepaths] - sum_sizes = sum(file_sizes) - max_size = max(file_sizes) - - if bucket.quota_left is not None: - if to_unlimited: - bucket.quota_size = None - - else: - # see how big the files are, and compare it against the bucket's quota - req_extra_quota = sum_sizes - bucket.quota_left - - # if we need some extra quota, increase it - if req_extra_quota > 0: - bucket.quota_size += req_extra_quota - - if bucket.max_file_size and bucket.max_file_size < max_size: - # do similar checks for the maximum file size - if to_unlimited: - bucket.max_file_size = None - else: - bucket.max_file_size = max_size - - # make changes known - db.session.flush() - - @click.group() def drafts(): """Management commands for creation and 
publication of drafts.""" @@ -162,7 +130,7 @@ def create_draft(metadata_path, publish, user, owner, vanity_pid): content = os.listdir(deposit_files_path) file_names = [ - basename(fn) for fn in content if file_exists(deposit_files_path, fn) + basename(fn) for fn in content if isfile(join(deposit_files_path, fn)) ] for fn in file_names: @@ -172,7 +140,7 @@ def create_draft(metadata_path, publish, user, owner, vanity_pid): ignored = [ basename(fn) for fn in content - if not file_exists(deposit_files_path, fn) + if not isfile(join(deposit_files_path, fn)) ] msg = f"ignored in '{deposit_files_path}': {ignored}" click.secho(msg, fg="yellow", err=True) @@ -309,32 +277,8 @@ def add_files(filepaths, pid, pid_type, user): file_service = service.draft_files draft = service.read_draft(id_=recid, identity=identity)._record - if not draft.files.enabled: - draft.files.enabled = True - draft.commit() - - paths = [] - for file_path in filepaths: - if isdir(file_path): - # add all files (no recursion into sub-dirs) from the directory - content = os.listdir(file_path) - file_names = [basename(fn) for fn in content if file_exists(file_path, fn)] - - if len(content) != len(file_names): - ignored = [ - basename(fn) for fn in content if not file_exists(file_path, fn) - ] - msg = f"ignored in '{file_path}': {ignored}" - click.secho(msg, fg="yellow", err=True) - - paths_ = [join(file_path, fn) for fn in file_names] - paths.extend(paths_) - - elif isfile(file_path): - paths.append(file_path) - - # make sure that the files fit in the bucket, and add them - auto_increase_bucket_limits(draft.files.bucket, paths) + # check if any of the files' basenames are duplicate + paths = collect_file_paths(filepaths) keys = [basename(fp) for fp in paths] if len(set(keys)) != len(keys): click.secho( @@ -343,27 +287,67 @@ def add_files(filepaths, pid, pid_type, user): sys.exit(1) # check for existing duplicates - files = list(file_service.list_files(id_=recid, identity=identity).entries()) - 
existing_file_keys = [e["key"] for e in files] + existing_file_keys = list(draft.files.entries.keys()) if any([k for k in keys if k in existing_file_keys]): click.secho( "aborting: reuse of existing file names detected", fg="yellow", err=True ) sys.exit(1) - # if all went well so far, continue on - file_service.init_files( - id_=recid, identity=identity, data=[{"key": basename(fp)} for fp in paths] - ) - for fp in paths: - fn = basename(fp) - with open(fp, "rb") as deposit_file: - file_service.set_file_content( - id_=recid, file_key=fn, identity=identity, stream=deposit_file - ) - file_service.commit_file(id_=recid, file_key=fn, identity=identity) + uow = UnitOfWork(db.session) + try: + # prepare the draft's file manager and bucket for files + bucket_was_locked = draft.files.bucket.locked + files_were_enabled = draft.files.enabled + draft.files.bucket.locked = False + if not files_were_enabled: + draft.files.enabled = True + draft.commit() + db.session.flush() + + auto_increase_bucket_limits(draft.files.bucket, paths) + file_service.init_files( + id_=recid, + identity=identity, + data=[{"key": basename(fp)} for fp in paths], + uow=uow, + ) - click.secho(recid, fg="green") + for fp in paths: + fn = basename(fp) + with open(fp, "rb") as deposit_file: + file_service.set_file_content( + id_=recid, + file_key=fn, + identity=identity, + stream=deposit_file, + uow=uow, + ) + file_service.commit_file(id_=recid, file_key=fn, identity=identity, uow=uow) + click.secho(recid, fg="green") + + # if the draft has already been published, we may need to enable the files for + # the published record as well + if draft.is_published and not files_were_enabled: + record = service.record_cls.get_record(draft.id) + record.files.enabled = True + record.commit() + + uow.commit() + + except Exception as e: + uow.rollback() + if draft.files.enabled != files_were_enabled: + draft.files.enabled = files_were_enabled + draft.commit() + + click.secho(f"aborted due to error: {e}", fg="red", 
err=True) + + finally: + if bucket_was_locked != draft.files.bucket.locked: + draft.files.bucket.locked = bucket_was_locked + + db.session.commit() @files.command("remove") diff --git a/invenio_utilities_tuw/cli/utils.py b/invenio_utilities_tuw/cli/utils.py index 9ef6dd272ab819df6a27e1f8a9de3594a3001067..e192b721ba4db881c74750a00c54b66ac10ef524 100644 --- a/invenio_utilities_tuw/cli/utils.py +++ b/invenio_utilities_tuw/cli/utils.py @@ -9,7 +9,10 @@ """Utilities for the CLI commands.""" import json +import os +from os.path import basename, isdir, isfile, join +import click from invenio_db import db from invenio_pidstore.errors import PIDAlreadyExists from invenio_pidstore.models import PersistentIdentifier @@ -126,3 +129,60 @@ def is_owned_by(user, record): """Check if the record is owned by the given user.""" owner = record.parent.access.owned_by return owner and owner.owner_id == user.id + + +def collect_file_paths(paths): + """Collect file paths from the given paths. + + If one of the given paths is a directory, its path will be replaced with the + paths of all files it contains. + If it contains any subdirectories however, then it will be skipped instead. 
+    """
+    paths_ = []
+    for path in paths:
+        if isdir(path):
+            # add all files (no recursion into sub-dirs) from the directory
+            content = os.listdir(path)
+            file_names = [basename(fn) for fn in content if isfile(join(path, fn))]
+
+            if len(content) != len(file_names):
+                ignored = [basename(fn) for fn in content if not isfile(join(path, fn))]
+                msg = f"ignored in '{path}': {ignored}"
+                click.secho(msg, fg="yellow", err=True)
+
+            dir_paths = [join(path, fn) for fn in file_names]
+            paths_.extend(dir_paths)
+
+        elif isfile(path):
+            paths_.append(path)
+
+    return paths_
+
+
+def auto_increase_bucket_limits(bucket, filepaths, to_unlimited=False):
+    """Dynamically increase the bucket quota if necessary."""
+    file_sizes = [os.path.getsize(filepath) for filepath in filepaths]
+    sum_sizes = sum(file_sizes)
+    max_size = max(file_sizes, default=0)
+
+    if bucket.quota_left is not None:
+        if to_unlimited:
+            bucket.quota_size = None
+
+        else:
+            # see how big the files are, and compare it against the bucket's quota
+            req_extra_quota = sum_sizes - bucket.quota_left
+
+            # if we need some extra quota, increase it
+            if req_extra_quota > 0:
+                bucket.quota_size += req_extra_quota
+
+    if bucket.max_file_size and bucket.max_file_size < max_size:
+        # do similar checks for the maximum file size
+        if to_unlimited:
+            bucket.max_file_size = None
+        else:
+            bucket.max_file_size = max_size
+
+    # make changes known
+    db.session.flush()