File size: 4,400 Bytes
f624d68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
079300d
 
f624d68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129

# gevent monkey-patching must run at the very beginning of the script, before the
# other imports (notably ssl); otherwise ssl ends up in infinite recursion.
# Fix found at https://stackoverflow.com/a/52130355
import gevent.monkey
gevent.monkey.patch_all()

import json
from datetime import date, datetime
import sys
import time

import huggingface_hub
import sqlite3
from tqdm import tqdm

# Module-wide Hugging Face Hub filesystem client; list_files_from_hub() calls
# fs.ls() on it for every directory it visits.
fs = huggingface_hub.HfFileSystem()

# NOTE(review): presumably a sibling module of this project (it provides the
# list_repos() used by the __main__ block); consider moving it up with the
# other imports.
import list_repos

# SQLite database file holding the cached `files` table. The path is relative,
# so it resolves against the current working directory.
SQLITE3_DB = "data/files.sqlite3"


def json_serial(obj):
    """Fallback encoder for ``json.dumps(..., default=json_serial)``.

    ``date`` and ``datetime`` values become ISO-8601 strings. Any other type
    raises ``TypeError`` so ``json`` reports it as unserializable.
    """
    if not isinstance(obj, (datetime, date)):
        raise TypeError(f"Type {type(obj)} not serializable")
    return obj.isoformat()


def list_files_from_hub(repo, replace_model_in_url=True):
    """Yield every file entry under ``repo`` on the Hub, descending into directories.

    Args:
        repo: Repo path handed to ``fs.ls``. When ``replace_model_in_url`` is
            true, a leading ``"models/"`` is removed: models are the default
            repo type, and the underlying ``fs.ls`` implementation already
            builds ``/api/models/<repo>`` URLs.
        replace_model_in_url: Whether to strip that leading prefix.

    Yields:
        The non-directory entry dicts returned by ``fs.ls``, depth-first and
        in the order ``fs.ls`` lists them.
    """
    model_prefix = "models/"
    if replace_model_in_url and repo.startswith(model_prefix):
        repo = repo[len(model_prefix):]

    # Walk the tree ourselves with one fs.ls request per directory. The
    # undocumented recursive=True does everything in a single request, which
    # really slams the server and can hit a backend timeout (HTTP 500). Many
    # small requests are much more likely to succeed.
    # Each stack element iterates over one directory's listing.
    exhausted = object()
    stack = [iter(fs.ls(repo))]
    while stack:
        entry = next(stack[-1], exhausted)
        if entry is exhausted:
            stack.pop()
        elif entry["type"] == "directory":
            stack.append(iter(fs.ls(entry["name"])))
        else:
            yield entry


def _file_to_row(repo, file):
    """Build the 13-value ``files`` row, in table column order, for one fs.ls file entry."""
    lfs = file["lfs"]
    last_commit = file["last_commit"]
    commit_date = last_commit["date"]
    if isinstance(commit_date, (date, datetime)):
        # Store ISO-8601 text explicitly (str() uses a space separator)
        # rather than relying on sqlite3's default date adapters, which are
        # deprecated since Python 3.12.
        commit_date = str(commit_date)
    return (
        file["name"],
        int(time.time()),
        repo,
        file["size"],
        file["type"],
        file["blob_id"],
        1 if lfs is not None else 0,
        # Real SQL NULLs (not the text 'NULL') for non-LFS files.
        lfs["size"] if lfs is not None else None,
        lfs["sha256"] if lfs is not None else None,
        lfs["pointer_size"] if lfs is not None else None,
        last_commit["oid"],
        last_commit["title"],
        commit_date,
    )


def write_files_to_db(repo, db_path=None):
    """Replace the stored file listing for ``repo`` with a fresh Hub listing.

    Ensures the ``files`` table exists, deletes every row previously stored
    for ``repo``, then inserts one row per file yielded by
    ``list_files_from_hub(repo)``.

    All values are bound as SQL parameters. Quotes in file names or commit
    titles (e.g. "don't") therefore cannot break the statement.

    The delete and every insert run in a single transaction. If listing the
    repo fails part-way (e.g. a server error), the previous rows for ``repo``
    are kept instead of being left half-replaced.

    Args:
        repo: Repo id passed to ``list_files_from_hub``; also stored in the
            ``repo`` column.
        db_path: SQLite database file. Defaults to ``SQLITE3_DB``.
    """
    if db_path is None:
        db_path = SQLITE3_DB
    print("Opening database", db_path, file=sys.stderr)
    con = sqlite3.connect(db_path)
    try:
        print("Creating files table if not exists", file=sys.stderr)
        con.execute(
            "CREATE TABLE IF NOT EXISTS files (name TEXT PRIMARY KEY, last_updated_datetime INTEGER, repo TEXT, size INTEGER, type TEXT, blob_id TEXT, is_lfs INTEGER, lfs_size INTEGER, lfs_sha256 TEXT, lfs_pointer_size INTEGER, last_commit_oid TEXT, last_commit_title TEXT, last_commit_date TEXT)"
        )
        con.commit()
        # `with con` commits when the body succeeds and rolls back on any exception.
        with con:
            print("Deleting existing rows for repo {}".format(repo), file=sys.stderr)
            con.execute("DELETE FROM files WHERE repo = ?", (repo,))
            print("Inserting new rows from HFFileSystem query for repo {}".format(repo), file=sys.stderr)
            for file in tqdm(list_files_from_hub(repo)):
                con.execute(
                    "INSERT INTO files VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
                    _file_to_row(repo, file),
                )
    finally:
        con.close()


def is_lfs(file):
    """Return True when the file entry carries LFS metadata, i.e. ``file["lfs"]`` is not None."""
    lfs_info = file["lfs"]
    return lfs_info is not None


def list_lfs_files(repo):
    """Yield the stored file entries of ``repo`` that have LFS metadata.

    Reads rows through ``list_files(repo)`` and keeps those for which
    ``is_lfs`` is true. This is a generator, so the database is only queried
    once the first item is requested.
    """
    # Filter directly instead of binding the rows to a local named `list`,
    # which shadowed the builtin.
    yield from filter(is_lfs, list_files(repo))


def list_files(repo, limit=None, db_path=None):
    """Return the files stored for ``repo`` in the SQLite ``files`` table.

    Args:
        repo: Value matched against the ``repo`` column. It is bound as a
            parameter, so quotes in it are safe.
        limit: Maximum number of rows to return. ``None`` returns all rows.
        db_path: SQLite database file. Defaults to ``SQLITE3_DB``.

    Returns:
        A list of dicts, each with these keys:

        - ``name``, ``last_updated_datetime``, ``size``, ``type``, ``blob_id``.
        - ``lfs``: a dict with ``size``, ``sha256`` and ``pointer_size`` when
          the row's ``is_lfs`` flag is set, otherwise ``None``.
        - ``last_commit``: a dict with ``oid``, ``title`` and ``date``.
    """
    if db_path is None:
        db_path = SQLITE3_DB
    # Name the columns explicitly. Positional access into `SELECT *` skipped
    # the `repo` column, which shifted every later field by one.
    query = (
        "SELECT name, last_updated_datetime, size, type, blob_id, is_lfs, "
        "lfs_size, lfs_sha256, lfs_pointer_size, "
        "last_commit_oid, last_commit_title, last_commit_date "
        "FROM files WHERE repo = ?"
    )
    params = [repo]
    if limit is not None:
        query += " LIMIT ?"
        params.append(limit)

    con = sqlite3.connect(db_path)
    try:
        rows = con.execute(query, params).fetchall()
    finally:
        con.close()

    return [
        {
            "name": name,
            "last_updated_datetime": updated,
            "size": size,
            "type": file_type,
            "blob_id": blob_id,
            "lfs": (
                {"size": lfs_size, "sha256": lfs_sha256, "pointer_size": lfs_pointer_size}
                if lfs_flag
                else None
            ),
            "last_commit": {"oid": commit_oid, "title": commit_title, "date": commit_date},
        }
        for (
            name,
            updated,
            size,
            file_type,
            blob_id,
            lfs_flag,
            lfs_size,
            lfs_sha256,
            lfs_pointer_size,
            commit_oid,
            commit_title,
            commit_date,
        ) in rows
    ]

if __name__ == "__main__":
    # Refresh the stored listing of every repo reported by list_repos.
    all_repos = list_repos.list_repos()
    for repo_id in all_repos:
        write_files_to_db(repo_id)
    print("Done writing to DB. Sample of 9 rows:")
    # Spot-check: up to 3 stored files from each of up to 3 repos. The
    # generator is lazy, so repos and rows are fetched as they are printed.
    sample_rows = (
        row
        for repo_id in list_repos.list_repos(limit=3)
        for row in list_files(repo_id, limit=3)
    )
    for row in sample_rows:
        print(json.dumps(row, default=json_serial))