From 59a48a750985d8a425fee2f293baec4b656a4f68 Mon Sep 17 00:00:00 2001
From: Maximilian Moser <maximilian.moser@tuwien.ac.at>
Date: Mon, 19 Feb 2024 16:41:55 +0100
Subject: [PATCH] Allow specification of known results for files

* Known results can mark specific files as "safe", so that they are not
  listed as endangered
* They can also be used to override the file format otherwise identified
  by siegfried
---
 formatscaper/core/models.py  |  9 +++++++--
 formatscaper/core/utils.py   | 17 +++++++++++++----
 formatscaper/formatscaper.py | 33 +++++++++++++++++++++++++++++++--
 3 files changed, 51 insertions(+), 8 deletions(-)

diff --git a/formatscaper/core/models.py b/formatscaper/core/models.py
index 85d68d3..28a54f4 100644
--- a/formatscaper/core/models.py
+++ b/formatscaper/core/models.py
@@ -43,13 +43,18 @@ class Result:
     """The format identification result for a given file."""
 
     filename: str
-    record: str
     format: Format
+    record: str = None
+    safe: bool = False
 
     def as_dict(self):
         """Dump the data as dictionary."""
-        return {
+        result = {
             "filename": self.filename,
             "record": self.record,
             "format": self.format.as_dict(),
         }
+        if self.safe:
+            result["safe"] = self.safe
+
+        return result
diff --git a/formatscaper/core/utils.py b/formatscaper/core/utils.py
index b565097..44b0f33 100644
--- a/formatscaper/core/utils.py
+++ b/formatscaper/core/utils.py
@@ -88,11 +88,13 @@ def store_results(results: List[Result], file_name: str, file_format: str) -> bo
 
 
 def load_results(
-    file_name: str, file_format: Optional[str] = None
+    file_name: str, file_format: Optional[str] = None, strict: bool = True
 ) -> Optional[List[Result]]:
     """Load the results from the given file.
 
     In case the ``file_format`` isn't specified, auto-detection is attempted.
+    If ``strict`` is set, loading fails when a result's format cannot be parsed;
+    otherwise, such results are loaded with their format set to ``None``.
     """
     if file_format is None:
         if re.search(r"\.ya?ml$", file_name, re.IGNORECASE):
@@ -115,9 +117,16 @@ def load_results(
     results = []
     known_formats = {}
     for res in raw_results:
-        format = known_formats.setdefault(
-            res["format"]["puid"], Format(**res["format"])
-        )
+        format = None
+        try:
+            f = Format(**res["format"])
+            format = known_formats.setdefault(res["format"]["puid"], f)
+        except (TypeError, KeyError) as e:
+            # TypeError: the result doesn't have all required parts for Format()
+            # KeyError:  either the result doesn't have a format or it lacks the PUID
+            if strict:
+                raise e
+
         res.pop("format", None)
         results.append(Result(**res, format=format))
 
diff --git a/formatscaper/formatscaper.py b/formatscaper/formatscaper.py
index f82533a..2b05b33 100755
--- a/formatscaper/formatscaper.py
+++ b/formatscaper/formatscaper.py
@@ -11,7 +11,13 @@ import progressbar as pb
 import yaml
 
 from core.models import Format, RecordFile, Result
-from core.utils import load_formats, load_record_files, store_formats, store_results
+from core.utils import (
+    load_formats,
+    load_record_files,
+    load_results,
+    store_formats,
+    store_results,
+)
 
 # set up the argument parser
 parser = argparse.ArgumentParser(
@@ -51,6 +57,12 @@ parser.add_argument(
     choices=["pickle", "yaml"],
     help="format of the results (default: pickle)",
 )
+parser.add_argument(
+    "--known-results",
+    "-k",
+    default=None,
+    help="file with known results for overriding the identification results",
+)
 parser.add_argument(
     "--parallel",
     "-p",
@@ -106,6 +118,13 @@ formats = load_formats(args.formats)
 # read the list of files to analyze
 record_files = load_record_files(args.input)
 
+# read the list of known results, keyed by filename
+known_results = {}
+if args.known_results is not None:
+    known_results = {
+        res.filename: res for res in load_results(args.known_results, strict=False)
+    }
+
 # try to redirect the error logs from siegfried
 try:
     sf_error_log = open(args.sf_error_log, "w")
@@ -193,8 +212,18 @@ def process_record_file(record_file: RecordFile) -> None:
                                     record=record_file.record,
                                     format=format,
                                 )
+
+                                # known results take precedence over siegfried's output
+                                if result.filename in known_results:
+                                    known_res = known_results[result.filename]
+                                    result.safe = known_res.safe
+                                    if known_res.format is not None:
+                                        result.format = formats.get(
+                                            known_res.format.puid, known_res.format
+                                        )
+
                                 all_results.append(result)
-                                if formats[format.puid].endangered:
+                                if formats[format.puid].endangered and not result.safe:
                                     endangered_files.append(result)
 
             # when the task ends, update the progress bar
-- 
GitLab