# found on https://stackoverflow.com/a/52130355 to fix infinite recursion with ssl
# at the beginning of the script
import gevent.monkey

gevent.monkey.patch_all()

import json
from datetime import date, datetime
import sys
import time
import huggingface_hub
import sqlite3
from tqdm import tqdm

fs = huggingface_hub.HfFileSystem()

import list_repos

SQLITE3_DB = "data/files.sqlite3"


def json_serial(obj):
    """JSON serializer for datetime objects, which json.dumps can't handle natively."""
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()
    raise TypeError("Type %s not serializable" % type(obj))


def list_files_from_hub(repo, replace_model_in_url=True):
    # Remove models/ from the front of repo, since the "default" type of repo
    # is a model: the underlying implementation of fs.ls appends repo as
    # /api/models/.
    if replace_model_in_url and repo.startswith("models/"):
        repo = repo.replace("models/", "", 1)
    # Implement our own recursive list, since it makes multiple requests
    # (one per directory ls), each of which is much more likely to succeed.
    # Passing recursive=True (which is undocumented anyway) does it in one
    # request, which really slams the server and can return a 500 error from
    # hitting some backend timeout.
    for item in fs.ls(repo):
        if item["type"] == "directory":
            yield from list_files_from_hub(item["name"], replace_model_in_url=False)
        else:
            yield item


def write_files_to_db(repo):
    print("Opening database", SQLITE3_DB, file=sys.stderr)
    con = sqlite3.connect(SQLITE3_DB)
    cur = con.cursor()
    print("Creating files table if not exists", file=sys.stderr)
    cur.execute(
        "CREATE TABLE IF NOT EXISTS files (name TEXT PRIMARY KEY, last_updated_datetime INTEGER, repo TEXT, size INTEGER, type TEXT, blob_id TEXT, is_lfs INTEGER, lfs_size INTEGER, lfs_sha256 TEXT, lfs_pointer_size INTEGER, last_commit_oid TEXT, last_commit_title TEXT, last_commit_date TEXT)"
    )
    con.commit()
    print("Deleting existing rows for repo {}".format(repo), file=sys.stderr)
    cur.execute("DELETE FROM files WHERE repo = ?", (repo,))
    con.commit()
    print("Inserting new rows from HfFileSystem query for repo {}".format(repo), file=sys.stderr)
    for file in tqdm(list_files_from_hub(repo)):
        is_lfs = file["lfs"] is not None
        # Use a parameterized query rather than string formatting. Values such
        # as commit titles can contain single quotes, which previously broke
        # the statement with errors like:
        #   sqlite3.OperationalError: near "t": syntax error
        # Binding None also stores a real SQL NULL instead of the string 'NULL'.
        cur.execute(
            "INSERT INTO files VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                file["name"],
                int(time.time()),
                repo,
                file["size"],
                file["type"],
                file["blob_id"],
                1 if is_lfs else 0,
                file["lfs"]["size"] if is_lfs else None,
                file["lfs"]["sha256"] if is_lfs else None,
                file["lfs"]["pointer_size"] if is_lfs else None,
                file["last_commit"]["oid"],
                file["last_commit"]["title"],
                # stored as text, matching the previous string-formatting behavior
                str(file["last_commit"]["date"]),
            ),
        )
    con.commit()


def is_lfs(file):
    return file["lfs"] is not None


def list_lfs_files(repo):
    # avoid naming the intermediate "list", which shadows the builtin
    for file in list_files(repo):
        if is_lfs(file):
            yield file


def list_files(repo, limit=None):
    con = sqlite3.connect(SQLITE3_DB)
    cur = con.cursor()
    if limit is None:
        res = cur.execute("SELECT * FROM files WHERE repo = ?", (repo,))
    else:
        res = cur.execute("SELECT * FROM files WHERE repo = ? LIMIT ?", (repo, limit))
    # Column order matches the CREATE TABLE statement above. repo sits at
    # row[2], so the remaining file fields start at row[3].
    return [
        {
            "name": row[0],
            "last_updated_datetime": row[1],
            "size": row[3],
            "type": row[4],
            "blob_id": row[5],
            "lfs": (
                {"size": row[7], "sha256": row[8], "pointer_size": row[9]}
                if row[6]
                else None
            ),
            "last_commit": {"oid": row[10], "title": row[11], "date": row[12]},
        }
        for row in res.fetchall()
    ]


if __name__ == "__main__":
    for repo in list_repos.list_repos():
        write_files_to_db(repo)
    print("Done writing to DB. Sample of 9 rows:")
    for repo in list_repos.list_repos(limit=3):
        for file in list_files(repo, limit=3):
            print(json.dumps(file, default=json_serial))