"""Scan model files (local paths or HTTP(S) URLs) with picklescan and ClamAV."""
import os
import tempfile
import urllib.parse
import urllib.request

import pyclamd
from picklescan.scanner import (
    scan_url,
    scan_file_path,
    ScanResult,
    SafetyLevel,
)


def _is_url(path: str) -> bool:
    """Return True if *path* is an HTTP(S) URL rather than a local path."""
    # A bare "http" prefix check would also match local files named "http...".
    return path.startswith(("http://", "https://"))


def scan_file(file_path: str) -> dict:
    """Scan a pickle-based model file with picklescan.

    Args:
        file_path: Local filesystem path or ``http(s)://`` URL.

    Returns:
        A dict with:
          - ``url``: the input ``file_path``.
          - ``fileExists``: always True (this function does no existence
            check of its own).
          - ``picklescanExitCode``: 0 = no dangerous globals, 1 = dangerous
            globals found, 2 = picklescan reported a scan error (intended to
            match the ``picklescan`` CLI exit codes).
          - ``picklescanGlobalImports``: every global referenced by the
            pickle, formatted as ``from <module> import <name>``.
          - ``picklescanDangerousImports``: the subset picklescan classifies
            as ``SafetyLevel.Dangerous``.
    """
    scan_result: ScanResult = (
        scan_url(file_path) if _is_url(file_path) else scan_file_path(file_path)
    )

    global_imports = [fmt_import(g.module, g.name) for g in scan_result.globals]
    dangerous_imports = [
        fmt_import(g.module, g.name)
        for g in scan_result.globals
        if g.safety == SafetyLevel.Dangerous
    ]

    if scan_result.scan_err:
        # A file picklescan could not parse must not be reported as clean.
        exit_code = 2
    elif dangerous_imports:
        exit_code = 1
    else:
        exit_code = 0

    return {
        'url': file_path,
        'fileExists': True,
        'picklescanExitCode': exit_code,
        'picklescanGlobalImports': global_imports,
        'picklescanDangerousImports': dangerous_imports,
        # TODO: clamscan results are produced separately by clamd_file();
        # 'hashes' and 'conversions' from the TS result type are not
        # implemented yet.
    }


def init_clamd():
    """Return a pyclamd client that talks to clamd over its Unix socket.

    Uses ``pyclamd.ClamdUnixSocket()`` with no arguments, i.e. pyclamd's
    default socket location.
    """
    return pyclamd.ClamdUnixSocket()


def _clamd_scan_path(scan_path: str, clamd) -> dict:
    """Run ``clamd.scan_file`` on *scan_path* and map the reply to a result dict."""
    ret = clamd.scan_file(scan_path)
    if not ret:
        return {
            'clamscanExitCode': 0,
            'clamscanOutput': "No virus found",
        }
    # pyclamd replies {reported_path: (status, detail)} with status 'FOUND'
    # or 'ERROR'. Iterate the values instead of indexing by our own path:
    # clamd's reported key need not equal the string we passed in.
    entries = list(ret.values())
    found = any(entry[0] == 'FOUND' for entry in entries)
    return {
        # 1 = virus found, 2 = scan error (clamscan's CLI convention).
        'clamscanExitCode': 1 if found else 2,
        'clamscanOutput': '\n'.join(' '.join(entry) for entry in entries),
    }


def clamd_file(file_path: str, clamd) -> dict:
    """Scan a local file or an HTTP(S) URL with ClamAV via *clamd*.

    A remote file is downloaded to a temporary file, scanned, and the
    temporary file is always deleted afterwards.

    Args:
        file_path: Local filesystem path or ``http(s)://`` URL.
        clamd: A pyclamd client, e.g. the return value of ``init_clamd()``.

    Returns:
        A dict with ``clamscanExitCode`` (0 = clean, 1 = virus found,
        2 = clamd reported an error) and ``clamscanOutput`` (human-readable
        result text).
    """
    if not _is_url(file_path):
        # clamd resolves the path itself, so it must be absolute.
        return _clamd_scan_path(os.path.abspath(file_path), clamd)

    remote_name = os.path.basename(urllib.parse.urlsplit(file_path).path)
    fd, tmp_path = tempfile.mkstemp(prefix="clamd_", suffix=f"_{remote_name}")
    os.close(fd)
    try:
        # mkstemp creates the file with mode 0600. clamd typically runs as a
        # separate (e.g. ``clamav``) user and must be able to read the download.
        os.chmod(tmp_path, 0o644)
        urllib.request.urlretrieve(file_path, tmp_path)
        return _clamd_scan_path(tmp_path, clamd)
    finally:
        # Model files can be very large; never leave the download behind.
        os.remove(tmp_path)


def fmt_import(module: str, name: str) -> str:
    """Format a pickle global as an import statement, e.g. ``from torch import FloatStorage``."""
    return f"from {module} import {name}"


if __name__ == "__main__":
    url = "https://huggingface.co/yesyeahvh/bad-hands-5/resolve/main/bad-hands-5.pt"
    detail = scan_file(url)
    clamd_detail = clamd_file(url, init_clamd())
    print(detail)
    print(clamd_detail)
    # Example picklescan ScanResult for the URL above:
    # ScanResult(
    #     globals=[
    #         Global(module='torch', name='FloatStorage', safety=<SafetyLevel.Innocuous: 'innocuous'>),
    #         Global(module='collections', name='OrderedDict', safety=<SafetyLevel.Innocuous: 'innocuous'>),
    #         Global(module='torch._utils', name='_rebuild_tensor_v2', safety=<SafetyLevel.Innocuous: 'innocuous'>),
    #     ],
    #     scanned_files=1, issues_count=0, infected_files=0, scan_err=False)