# model-scan-2 / scan.py
# Last commit: 8787dd3 "fix now" by pengdaqian
import pyclamd
from picklescan.scanner import (
scan_url,
scan_file_path,
ScanResult, SafetyLevel
)
# def scan_file(file_path: str):
# ret = scan_pickle_bytes(io.BytesIO(pickle.dumps(file_path)), "file.pkl")
# print(ret)
def scan_file(file_path: str):
    """Run picklescan over a local file or an HTTP(S) URL and summarise the result.

    Args:
        file_path: Local filesystem path, or a URL beginning with ``http://`` or
            ``https://`` (handed to ``picklescan.scanner.scan_url``).

    Returns:
        A dict with keys:
          - ``url``: the input exactly as given.
          - ``fileExists``: always True here.
          - ``picklescanExitCode``: 1 if any global is ``SafetyLevel.Dangerous``;
            otherwise 2 if picklescan reported a scan error; otherwise 0.
          - ``picklescanGlobalImports``: every global referenced by the scanned
            pickle(s), rendered with ``fmt_import``.
          - ``picklescanDangerousImports``: the subset classified as
            ``SafetyLevel.Dangerous``, rendered the same way.
    """
    # Match real URL schemes only, so a local path such as "http_models/x.pt"
    # is not mistaken for a URL.
    if file_path.startswith(("http://", "https://")):
        scan_result: ScanResult = scan_url(file_path)
    else:
        scan_result: ScanResult = scan_file_path(file_path)

    global_imports = [fmt_import(g.module, g.name) for g in scan_result.globals]
    dangerous_imports = [
        fmt_import(g.module, g.name)
        for g in scan_result.globals
        if g.safety == SafetyLevel.Dangerous
    ]

    if dangerous_imports:
        picklescan_exit_code = 1
    elif scan_result.scan_err:
        # A scan that failed must not pass as clean. Exit code 2 follows
        # picklescan's own 0 (clean) / 1 (dangerous) / 2 (error) convention.
        picklescan_exit_code = 2
    else:
        picklescan_exit_code = 0

    # clamscanExitCode / clamscanOutput are produced separately by clamd_file.
    # Hashes and conversions are not computed here.
    return {
        'url': file_path,
        'fileExists': True,
        'picklescanExitCode': picklescan_exit_code,
        'picklescanGlobalImports': global_imports,
        'picklescanDangerousImports': dangerous_imports,
    }
def init_clamd():
    """Create a clamd client that talks to the daemon over its Unix socket.

    Returns:
        A ``pyclamd.ClamdUnixSocket`` built with pyclamd's default constructor
        arguments (no socket path or timeout is passed here).
    """
    return pyclamd.ClamdUnixSocket()
def clamd_file(file_path: str, clamd=None):
    """Scan a local file or an HTTP(S) URL with ClamAV through a clamd client.

    ``clamd.scan_file`` takes a filesystem path, so a URL is first downloaded
    to ``/tmp/clamd_<basename>``. The download is deleted once the scan
    finishes or fails.

    Args:
        file_path: Local path, or a URL starting with ``http://``/``https://``.
        clamd: Client exposing ``scan_file(path)`` (e.g. from ``init_clamd()``).
            When omitted, one is created with ``init_clamd()``.

    Returns:
        ``{'clamscanExitCode': 0, 'clamscanOutput': "No virus found"}`` when the
        client reports nothing. Otherwise ``clamscanExitCode`` is 1 for a
        ``FOUND`` verdict and 2 for any other status (e.g. ``ERROR``), and
        ``clamscanOutput`` is the space-joined verdict, e.g.
        ``"FOUND Eicar-Test-Signature"``.
    """
    import os
    import urllib.request

    if clamd is None:
        clamd = init_clamd()

    tmp_path = None
    try:
        if file_path.startswith(("http://", "https://")):
            # Name the download after the URL's last path segment, minus any
            # query string. split("/") guarantees no separator survives, so
            # the file always lands directly in /tmp.
            tmp_path = f'/tmp/clamd_{file_path.split("/")[-1].split("?")[0]}'
            print("tmp_path ", tmp_path)
            urllib.request.urlretrieve(file_path, tmp_path)
            scan_path = tmp_path
        else:
            # pyclamd expects an absolute path. A relative one would be
            # resolved by the daemon, not by this process.
            scan_path = os.path.abspath(file_path)
        ret = clamd.scan_file(scan_path)
    finally:
        # Don't leave downloaded model files behind in /tmp, even on failure.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except FileNotFoundError:
                pass

    if not ret:
        return {
            'clamscanExitCode': 0,
            'clamscanOutput': "No virus found",
        }

    # Results are keyed by the path that was scanned (the temp file for URLs,
    # never the URL itself). Fall back to the first entry in case the daemon
    # reports a normalised path.
    verdict = ret.get(scan_path) or next(iter(ret.values()))
    return {
        'clamscanExitCode': 1 if verdict[0] == 'FOUND' else 2,
        'clamscanOutput': ' '.join(verdict),
    }
def fmt_import(module: str, name: str) -> str:
    """Render a pickle global reference as a Python import statement.

    Args:
        module: Module path of the global, e.g. ``"torch._utils"``.
        name: Attribute name inside that module, e.g. ``"_rebuild_tensor_v2"``.

    Returns:
        A plain string of the form ``"from <module> import <name>"``.
    """
    return f"from {module} import {name}"
if __name__ == "__main__":
    # Manual smoke test against a public Hugging Face model file.
    model_url = "https://huggingface.co/yesyeahvh/bad-hands-5/resolve/main/bad-hands-5.pt"
    detail = scan_file(model_url)
    # clamd_file scans through the pyclamd client passed as its second argument.
    clamd_detail = clamd_file(model_url, init_clamd())
    print(detail)
    print(clamd_detail)
    # Sample picklescan ScanResult (presumably for the URL above):
    # ScanResult(
    #     globals=[
    #         Global(module='torch', name='FloatStorage', safety=<SafetyLevel.Innocuous: 'innocuous'>),
    #         Global(module='collections', name='OrderedDict', safety=<SafetyLevel.Innocuous: 'innocuous'>),
    #         Global(module='torch._utils', name='_rebuild_tensor_v2', safety=<SafetyLevel.Innocuous: 'innocuous'>),
    #     ],
    #     scanned_files=1, issues_count=0, infected_files=0, scan_err=False,
    # )