# found on https://stackoverflow.com/a/52130355 to fix infinite recursion with ssl
# at the beginning of the script
import gevent.monkey
gevent.monkey.patch_all()
import json
from datetime import date, datetime
import sys
import time
import huggingface_hub
import sqlite3
from tqdm import tqdm
fs = huggingface_hub.HfFileSystem()
import list_repos
SQLITE3_DB = "data/files.sqlite3"
def json_serial(obj):
    """`json.dumps(default=...)` hook: serialize datetime/date as ISO-8601.

    Raises TypeError for any other type, mirroring json's own behavior.
    """
    if not isinstance(obj, (datetime, date)):
        raise TypeError("Type %s not serializable" % type(obj))
    return obj.isoformat()
def list_files_from_hub(repo, replace_model_in_url=True):
    """Recursively yield every non-directory entry of *repo* from the Hub.

    We deliberately issue one fs.ls() call per directory instead of passing
    recursive=True (undocumented): a single recursive listing slams the
    server with one huge request that can hit a backend timeout and 500,
    while many small per-directory requests are far more likely to succeed.

    The "models/" prefix is stripped only on the top-level call: models are
    the hub's default repo type, and fs.ls already requests
    /api/models/<repo> under the hood.
    """
    if replace_model_in_url and repo.startswith("models/"):
        repo = repo.replace("models/", "", 1)
    for entry in fs.ls(repo):
        if entry["type"] != "directory":
            yield entry
        else:
            # Directory names returned by fs.ls are already full paths,
            # so no prefix stripping on the recursive call.
            yield from list_files_from_hub(entry["name"], replace_model_in_url=False)
def write_files_to_db(repo):
    """Replace all rows for *repo* in the local `files` table with a fresh
    listing fetched from the Hugging Face Hub.

    Creates the table on first use, deletes the repo's old rows, then
    inserts one row per file yielded by list_files_from_hub(). Progress is
    logged to stderr; tqdm shows per-file progress.
    """
    print("Opening database", SQLITE3_DB, file=sys.stderr)
    con = sqlite3.connect(SQLITE3_DB)
    try:
        cur = con.cursor()
        print("Creating files table if not exists", file=sys.stderr)
        cur.execute(
            "CREATE TABLE IF NOT EXISTS files (name TEXT PRIMARY KEY, last_updated_datetime INTEGER, repo TEXT, size INTEGER, type TEXT, blob_id TEXT, is_lfs INTEGER, lfs_size INTEGER, lfs_sha256 TEXT, lfs_pointer_size INTEGER, last_commit_oid TEXT, last_commit_title TEXT, last_commit_date TEXT)"
        )
        con.commit()
        print("Deleting existing rows for repo {}".format(repo), file=sys.stderr)
        # Parameterized query: the old '{}'.format() version broke (and was
        # injectable) whenever repo contained a quote.
        cur.execute("DELETE FROM files WHERE repo = ?", (repo,))
        con.commit()
        print("Inserting new rows from HFFileSystem query for repo {}".format(repo), file=sys.stderr)
        for file in tqdm(list_files_from_hub(repo)):
            lfs = file["lfs"]
            # Bound parameters fix the intermittent
            #   sqlite3.OperationalError: near "t": syntax error
            # seen with the previous string-formatted INSERT whenever a
            # value (e.g. a commit title with an apostrophe) contained a
            # quote. Binding None also stores real SQL NULLs instead of the
            # literal string 'NULL' for missing LFS fields.
            cur.execute(
                "INSERT INTO files VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
                (
                    file["name"],
                    int(time.time()),
                    repo,
                    file["size"],
                    file["type"],
                    file["blob_id"],
                    1 if lfs is not None else 0,
                    lfs["size"] if lfs is not None else None,
                    lfs["sha256"] if lfs is not None else None,
                    lfs["pointer_size"] if lfs is not None else None,
                    file["last_commit"]["oid"],
                    file["last_commit"]["title"],
                    # str() keeps the stored text identical to what the old
                    # format-string version wrote (the hub may return a
                    # datetime here — TODO confirm).
                    str(file["last_commit"]["date"]),
                ),
            )
        con.commit()
    finally:
        # Previously the connection was never closed.
        con.close()
def is_lfs(file):
    # A file entry is LFS-tracked when its hub metadata carries an "lfs"
    # record (None otherwise).
    lfs_info = file["lfs"]
    return lfs_info is not None
def list_lfs_files(repo):
    """Yield only the LFS-tracked file dicts for *repo* from the local DB.

    Thin filter over list_files(); see that function for the dict shape.
    """
    # Renamed the local from `list`, which shadowed the builtin.
    for file in list_files(repo):
        if is_lfs(file):
            yield file
def list_files(repo, limit=None, db_path=None):
    """Return the stored file rows for *repo* as a list of dicts.

    Parameters:
        repo: value of the `repo` column to select.
        limit: optional maximum number of rows to return.
        db_path: optional sqlite database path; defaults to SQLITE3_DB.
              (Added, backward-compatibly, so callers/tests can point at a
              different database file.)

    Each dict mirrors the shape of a list_files_from_hub() entry: name,
    last_updated_datetime, size, type, blob_id, lfs (dict or None) and
    last_commit (dict).
    """
    con = sqlite3.connect(SQLITE3_DB if db_path is None else db_path)
    try:
        cur = con.cursor()
        # Parameterized queries: the old '{}'.format() versions broke on
        # quotes in repo and were injectable.
        if limit is None:
            res = cur.execute("SELECT * FROM files WHERE repo = ?", (repo,))
        else:
            res = cur.execute(
                "SELECT * FROM files WHERE repo = ? LIMIT ?", (repo, limit)
            )
        # Column order follows the CREATE TABLE statement:
        #   0 name, 1 last_updated_datetime, 2 repo, 3 size, 4 type,
        #   5 blob_id, 6 is_lfs, 7 lfs_size, 8 lfs_sha256,
        #   9 lfs_pointer_size, 10 last_commit_oid, 11 last_commit_title,
        #   12 last_commit_date.
        # The previous mapping was off by one from "size" onward (it
        # skipped the `repo` column at index 2 without advancing the
        # indices), and keyed the lfs check on row[5] — the blob_id string,
        # which is always truthy, so "lfs" was never None.
        return [
            {
                "name": row[0],
                "last_updated_datetime": row[1],
                "size": row[3],
                "type": row[4],
                "blob_id": row[5],
                "lfs": (
                    {"size": row[7], "sha256": row[8], "pointer_size": row[9]}
                    if row[6]
                    else None
                ),
                "last_commit": {"oid": row[10], "title": row[11], "date": row[12]},
            }
            for row in res.fetchall()
        ]
    finally:
        # Previously the connection was never closed.
        con.close()
if __name__ == "__main__":
    # Rebuild the files table for every known repo, then print a small
    # sample (3 repos x 3 files) as JSON for a quick sanity check.
    for repo_name in list_repos.list_repos():
        write_files_to_db(repo_name)
    print("Done writing to DB. Sample of 9 rows:")
    for repo_name in list_repos.list_repos(limit=3):
        for entry in list_files(repo_name, limit=3):
            print(json.dumps(entry, default=json_serial))