|
import math |
|
from base64 import b32encode, b32decode |
|
from pybase64 import urlsafe_b64encode, urlsafe_b64decode |
|
from loguru import logger as log |
|
import os |
|
import time |
|
from pathlib import Path |
|
from urllib.request import urlretrieve |
|
from blake3 import blake3 |
|
from platformdirs import PlatformDirs |
|
from typing import List, Tuple |
|
from iscc_sct.models import Metadata, Feature |
|
|
|
|
|
# Application identity used by platformdirs to derive the per-user data directory.
APP_NAME = "iscc-sct"
APP_AUTHOR = "iscc"
dirs = PlatformDirs(appname=APP_NAME, appauthor=APP_AUTHOR)
# Ensure the user data dir exists at import time — it caches the downloaded ONNX model.
os.makedirs(dirs.user_data_dir, exist_ok=True)
|
|
|
|
|
# Public API of this module (check_integrity is intentionally internal-only).
__all__ = [
    "timer",
    "get_model",
    "encode_base32",
    "encode_base64",
    "decode_base32",
    "decode_base64",
    "hamming_distance",
    "iscc_distance",
    "cosine_similarity",
    "granular_similarity",
    "MODEL_PATH",
]
|
|
|
|
|
# Release tag of the iscc-binaries distribution that hosts the model file.
BASE_VERSION = "1.0.0"
BASE_URL = f"https://github.com/iscc/iscc-binaries/releases/download/v{BASE_VERSION}"
MODEL_FILENAME = "iscc-sct-v0.1.0.onnx"
MODEL_URL = f"{BASE_URL}/{MODEL_FILENAME}"
# Local cache path of the embedding model inside the per-user data directory.
MODEL_PATH = Path(dirs.user_data_dir) / MODEL_FILENAME
# Expected blake3 hex digest of the model file (verified by check_integrity).
MODEL_CHECKSUM = "ff254d62db55ed88a1451b323a66416f60838dd2f0338dba21bc3b8822459abc"
|
|
|
|
|
class timer:
    """Context manager that logs the wall-clock time spent in its body.

    Usage::

        with timer("Some operation took"):
            do_work()

    On exit the message followed by the elapsed seconds is logged at DEBUG level.
    """

    def __init__(self, message: str):
        # Prefix for the log line emitted on exit.
        self.message = message

    def __enter__(self):
        # perf_counter is monotonic and high-resolution — suitable for short intervals.
        self.start_time = time.perf_counter()
        # Fix: return self so `with timer(...) as t` binds the instance (was None).
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Runs even when the body raised; does not suppress the exception.
        elapsed_time = time.perf_counter() - self.start_time
        log.debug(f"{self.message} {elapsed_time:.4f} seconds")
|
|
|
|
|
def get_model():
    """Check and return local model file if it exists, otherwise download.

    The file is verified against MODEL_CHECKSUM; a corrupt local copy is
    re-downloaded once before the final verification.
    """
    # Fresh install: fetch the model and verify it.
    if not MODEL_PATH.exists():
        log.info("Downloading embedding model ...")
        urlretrieve(MODEL_URL, filename=MODEL_PATH)
        return check_integrity(MODEL_PATH, MODEL_CHECKSUM)

    # Cached copy: verify it, re-downloading once on corruption.
    try:
        return check_integrity(MODEL_PATH, MODEL_CHECKSUM)
    except RuntimeError:
        log.warning("Model file integrity error - redownloading ...")
        urlretrieve(MODEL_URL, filename=MODEL_PATH)
    # Re-verify the fresh download; raises RuntimeError if it is still bad.
    return check_integrity(MODEL_PATH, MODEL_CHECKSUM)
|
|
|
|
|
def check_integrity(file_path, checksum):
    """
    Check file integrity against blake3 checksum.

    :param file_path: path to file to be checked
    :param checksum: blake3 checksum to verify integrity
    :return: the file path as a Path object when verification succeeds
    :raises RuntimeError: if verification fails
    """
    file_path = Path(file_path)
    # Memory-mapped, multi-threaded hashing keeps large model files fast.
    hasher = blake3(max_threads=blake3.AUTO)
    with timer("INTEGRITY check time"):
        hasher.update_mmap(file_path)
    if hasher.hexdigest() != checksum:
        msg = f"Failed integrity check for {file_path.name}"
        log.error(msg)
        raise RuntimeError(msg)
    return file_path
|
|
|
|
|
def encode_base32(data):
    """
    Standard RFC4648 base32 encoding without padding.

    :param bytes data: Data for base32 encoding
    :return: Base32 encoded str
    """
    encoded = b32encode(data)
    # Drop the "=" padding on the bytes level, then decode to text.
    return encoded.rstrip(b"=").decode("ascii")
|
|
|
|
|
def decode_base32(code):
    """
    Standard RFC4648 base32 decoding without padding and with casefolding.

    :param str code: Unpadded base32 string (either case)
    :return: Decoded bytes
    """
    # Restore the "=" padding that encode_base32 stripped (multiple of 8 chars).
    pad_length = -len(code) % 8
    padded = code + "=" * pad_length
    return bytes(b32decode(padded, casefold=True))
|
|
|
|
|
def encode_base64(data):
    """
    Standard RFC4648 base64url encoding without padding.

    :param bytes data: Data for base64url encoding
    :return: Base64url encoded str
    """
    # Encode, decode to text, then strip the trailing "=" padding.
    return urlsafe_b64encode(data).decode("ascii").rstrip("=")
|
|
|
|
|
def decode_base64(code):
    """
    Standard RFC4648 base64url decoding without padding.

    :param str code: Unpadded base64url encoded string
    :return: Decoded bytes
    """
    # Add only the padding needed to reach a multiple of 4 characters.
    # Fix: the previous `4 - (len(code) % 4)` appended four superfluous "="
    # when the length was already a multiple of 4 — strict decoders reject that.
    padding = -len(code) % 4
    return urlsafe_b64decode(code + "=" * padding)
|
|
|
|
|
def hamming_distance(a, b):
    """
    Calculate the bitwise Hamming distance between two bytes objects.

    :param a: The first bytes object.
    :param b: The second bytes object.
    :return: The Hamming distance between two bytes objects.
    :raise ValueError: If a and b are not the same length.
    """
    if len(a) != len(b):
        raise ValueError("The lengths of the two bytes objects must be the same")
    # XOR each byte pair and count the set bits of the result.
    return sum(bin(x ^ y).count("1") for x, y in zip(a, b))
|
|
|
|
|
def iscc_distance(iscc1, iscc2):
    """
    Calculate the Hamming distance between two ISCC Semantic Text Codes.

    :param iscc1: The first ISCC Semantic Text Code.
    :param iscc2: The second ISCC Semantic Text Code.
    :return: The Hamming distance between the two ISCC codes.
    :raise ValueError: If the input ISCCs are not valid or of different lengths.
    """
    # Strip the canonical "ISCC:" scheme prefix when present.
    prefix = "ISCC:"
    code1 = iscc1[len(prefix):] if iscc1.startswith(prefix) else iscc1
    code2 = iscc2[len(prefix):] if iscc2.startswith(prefix) else iscc2

    raw1 = decode_base32(code1)
    raw2 = decode_base32(code2)
    if len(raw1) != len(raw2):
        raise ValueError("The input ISCCs must have the same length")

    # Skip the 2-byte header of each code; compare only the content bits.
    return hamming_distance(raw1[2:], raw2[2:])
|
|
|
|
|
def cosine_similarity(a, b):
    """
    Calculate the approximate cosine similarity based on Hamming distance for two bytes inputs.

    :param a: The first bytes object.
    :param b: The second bytes object.
    :return: The approximate cosine similarity between the two inputs, scaled from -100 to +100.
    :raise ValueError: If a and b are not the same length.
    """
    if len(a) != len(b):
        raise ValueError("The lengths of the two bytes objects must be the same")

    total_bits = 8 * len(a)
    # Map the bit-error rate onto [-1, 1], scale to an integer in [-100, 100].
    score = int(100 * (1 - 2 * hamming_distance(a, b) / total_bits))
    return max(-100, min(100, score))
|
|
|
|
|
def granular_similarity(metadata_a, metadata_b, threshold=80):
    """
    Compare simprints from two Metadata objects and return matching pairs above a similarity
    threshold. Only the most similar pair for each simprint_a is included.

    :param metadata_a: The first Metadata object.
    :param metadata_b: The second Metadata object.
    :param threshold: The similarity threshold (0-100) above which simprints are considered a match.
    :return: A list of tuples containing matching simprints and their similarity.
    """
    obj_a = metadata_a.to_object_format()
    obj_b = metadata_b.to_object_format()

    matches = []
    for feature_set_a in obj_a.features:
        for simprint_a in feature_set_a.simprints:
            # Decode once per simprint_a; it is compared against every simprint_b.
            raw_a = decode_base64(simprint_a.simprint)
            # Seeding with threshold - 1 means only similarity >= threshold can win.
            best_score = threshold - 1
            best_pair = None
            for feature_set_b in obj_b.features:
                for simprint_b in feature_set_b.simprints:
                    score = cosine_similarity(raw_a, decode_base64(simprint_b.simprint))
                    if score > best_score:
                        best_score = score
                        best_pair = (simprint_a, score, simprint_b)
            if best_pair:
                matches.append(best_pair)

    return matches
|
|