diff --git a/formatscaper/formatscaper.py b/formatscaper/formatscaper.py
index 8c0d17d1f8a41ac4855c2ac7ad8c1421cb3c6d07..0895b1f592cf70e2ee37bd363721ee716323e377 100755
--- a/formatscaper/formatscaper.py
+++ b/formatscaper/formatscaper.py
@@ -5,6 +5,7 @@ import dataclasses
 import re
 import subprocess
 import sys
+import threading
 
 import progressbar
 import yaml
@@ -31,7 +32,6 @@ parser.add_argument(
     "--formats",
     "--f",
     default="formats.yml",
-    nargs="?",
     help="list of known file formats and if they're endangered; this file will be updated (default: formats.yml)",  # noqa
 )
 parser.add_argument(
@@ -46,17 +46,18 @@ parser.add_argument(
     default="results.yml",
     help="file in which to store the identified format for each file (default: results.yml)",  # noqa
 )
+parser.add_argument(
+    "--parallel",
+    "-p",
+    default=1,
+    type=int,
+    help="number of siegfried processes to run in parallel (default: 1)",
+)
 parser.add_argument(
     "--sf-binary",
     default="sf",
     help="name of the siegfried binary to call (default: sf)",
 )
-parser.add_argument(
-    "--sf-parallel",
-    default=1,
-    type=int,
-    help="number of parallel processes used by sf (default: 1)",
-)
 parser.add_argument(
     "--sf-error-log",
     default="sf.log",
@@ -64,7 +65,7 @@ parser.add_argument(
 )
 parser.add_argument(
     "--no-progressbar",
-    "-P",
+    "-B",
     default=False,
     action="store_true",
     help="disable the progress bar",
@@ -128,56 +129,106 @@ except OSError as e:
     sf_error_log = None
 
 
-# analyze each file listed in the record files
+# set up variables required in the collection of results
 all_results = []
 endangered_files = []
-if not args.no_progressbar:
-    bar = progressbar.ProgressBar(
-        widgets=[
-            progressbar.Percentage(), " (", progressbar.SimpleProgress(), ") ", progressbar.Bar(), " ", progressbar.Timer(),  # noqa
-        ],
-    )
-    record_files = bar(record_files or [])
+# clamp to at least one slot, as a semaphore with zero slots would block every
+# worker forever
+sem = threading.Semaphore(max(1, args.parallel))
+mutex = threading.Lock()
+completed_tasks = 0
+progress_bar = progressbar.ProgressBar(
+    # guard against a missing list of record files, as the loop below does
+    max_value=len(record_files or []),
+    widgets=[
+        # fmt: off
+        progressbar.Percentage(), " (", progressbar.SimpleProgress(), ") ", progressbar.Bar(), " ", progressbar.Timer(),  # noqa
+        # fmt: on
+    ],
+)
+
+
+def _collect_results(record_file):
+    """Run siegfried on one record file and store the identified formats.
+
+    Each PRONOM match reported by siegfried is appended to ``all_results``,
+    and also to ``endangered_files`` if its format is marked as endangered.
+    Formats that are not known yet get registered in ``formats``.
+    """
+    # only the siegfried call itself is limited by the semaphore
+    with sem:
+        sf_output = subprocess.check_output(
+            [
+                args.sf_binary,
+                "-z",
+                "-multi",
+                "1",
+                "-name",
+                record_file["filename"],
+                record_file["uri"],
+            ],
+            stderr=sf_error_log,
+        )
+
+    # skip the sf info part
+    file_infos = yaml.safe_load_all(sf_output)
+    next(file_infos)
+
+    # go through all the files analyzed by siegfried which can be several,
+    # if the original input file was an archive
+    for file_info in file_infos:
+        if not file_info.get("errors", None) and file_info.get("matches", []):
+            for match in file_info["matches"]:
+                if match["ns"] == "pronom":
+                    file_format = Format(
+                        name=match["format"],
+                        puid=match["id"],
+                        mime=match["mime"],
+                        endangered=False,
+                    )
+
+                    # the storing of results needs to be mutually exclusive
+                    with mutex:
+                        # prefer the already registered format, if there is one
+                        file_format = formats.setdefault(file_format.puid, file_format)
+                        result = Result(
+                            filename=file_info["filename"],
+                            record=record_file["record"],
+                            format=file_format,
+                        )
+                        all_results.append(result)
+                        if file_format.endangered:
+                            endangered_files.append(result)
+
+
+def process_record_file(record_file):
+    """Analyze one record file and count the task as completed afterwards.
+
+    The progress is updated even if the analysis raised an exception, so that
+    a failing file does not keep the progress bar from advancing.
+    """
+    global completed_tasks
+
+    try:
+        _collect_results(record_file)
+    finally:
+        # when the task ends, update the progress bar
+        with mutex:
+            completed_tasks += 1
+            if not args.no_progressbar:
+                progress_bar.update(completed_tasks)
+
+
+# analyze all the files in parallel, and create the summary after all threads complete
+threads = []
 
 for record_file in record_files or []:
-    sf_output = subprocess.check_output(
-        [
-            args.sf_binary,
-            "-z",
-            "-multi",
-            str(args.sf_parallel),
-            "-name",
-            record_file["filename"],
-            record_file["uri"],
-        ],
-        stderr=sf_error_log,
-    )
+    thread = threading.Thread(target=process_record_file, args=[record_file])
+    threads.append(thread)
+    thread.start()
 
-    # skip the sf info part
-    file_infos = yaml.safe_load_all(sf_output)
-    next(file_infos)
-
-    # go through all the files analyzed by siegfried which can be several,
-    # if the original input file was an archive
-    for file_info in file_infos:
-        if not file_info.get("errors", None) and file_info.get("matches", []):
-            for match in file_info["matches"]:
-                if match["ns"] == "pronom":
-                    format = Format(
-                        name=match["format"],
-                        puid=match["id"],
-                        mime=match["mime"],
-                        endangered=False,
-                    )
-                    format = formats.setdefault(format.puid, format)
-                    result = Result(
-                        filename=file_info["filename"],
-                        record=record_file["record"],
-                        format=format,
-                    )
-                    all_results.append(result)
-                    if formats[format.puid].endangered:
-                        endangered_files.append(result)
+for thread in threads:
+    thread.join()
 
 if sf_error_log is not None:
     sf_error_log.close()