Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -18,69 +18,17 @@ app = Flask(__name__, static_url_path='/static')
|
|
18 |
|
19 |
CORS(app)
|
20 |
|
21 |
-
TOKEN = os.environ.get('dataset_token')
|
22 |
|
23 |
-
DB_FILE = Path("./prompts.db")
|
24 |
-
|
25 |
-
repo = Repository(
|
26 |
-
local_dir="data",
|
27 |
-
repo_type="dataset",
|
28 |
-
clone_from="huggingface-projects/wordalle_guesses",
|
29 |
-
use_auth_token=TOKEN
|
30 |
-
)
|
31 |
-
repo.git_pull()
|
32 |
-
# copy db on db to local path
|
33 |
-
shutil.copyfile("./data/prompts.db", DB_FILE)
|
34 |
-
|
35 |
-
dataset = load_dataset(
|
36 |
-
"huggingface-projects/wordalle_prompts",
|
37 |
-
use_auth_token=TOKEN)
|
38 |
|
39 |
Path("static/images").mkdir(parents=True, exist_ok=True)
|
40 |
|
41 |
-
db = sqlite3.connect(DB_FILE)
|
42 |
-
try:
|
43 |
-
data = db.execute("SELECT * FROM prompts").fetchall()
|
44 |
-
db.close()
|
45 |
-
except sqlite3.OperationalError:
|
46 |
-
db.execute('CREATE TABLE prompts (guess TEXT, correct TEXT)')
|
47 |
-
db.commit()
|
48 |
-
|
49 |
-
# extract images and prompts from dataset and save to dis
|
50 |
-
data = {}
|
51 |
-
for row in dataset['train']:
|
52 |
-
prompt = dataset['train'].features['label'].int2str(row['label'])
|
53 |
-
image = row['image']
|
54 |
-
hash = uuid.uuid4().hex
|
55 |
-
image_file = Path(f'static/images/{hash}.jpg')
|
56 |
-
image_compress = image.resize((136, 136), Image.Resampling.LANCZOS)
|
57 |
-
image_compress.save(image_file, optimize=True, quality=95)
|
58 |
-
if prompt not in data:
|
59 |
-
data[prompt] = []
|
60 |
-
data[prompt].append(str(image_file))
|
61 |
-
|
62 |
-
with open('static/data.json', 'w') as f:
|
63 |
-
json.dump(data, f)
|
64 |
|
65 |
|
66 |
def update_repository():
|
67 |
-
|
68 |
-
# copy db on db to local path
|
69 |
-
shutil.copyfile(DB_FILE, "./data/prompts.db")
|
70 |
-
|
71 |
-
with sqlite3.connect("./data/prompts.db") as db:
|
72 |
-
db.row_factory = sqlite3.Row
|
73 |
-
result = db.execute("SELECT * FROM prompts").fetchall()
|
74 |
-
# data = [dict(row) for row in result]
|
75 |
-
os
|
76 |
-
# with open('./data/data.json', 'w') as f:
|
77 |
-
# json.dump(data, f, separators=(',', ':'))
|
78 |
|
79 |
print("Updating repository")
|
80 |
-
|
81 |
-
"git add . && git commit --amend -m 'update' && git push --force", cwd="./data", shell=True)
|
82 |
-
# repo.push_to_hub(blocking=False)
|
83 |
-
|
84 |
|
85 |
@ app.route('/')
|
86 |
def index():
|
|
|
18 |
|
19 |
CORS(app)
|
20 |
|
|
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
Path("static/images").mkdir(parents=True, exist_ok=True)
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
|
27 |
def update_repository():
|
28 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
print("Updating repository")
|
31 |
+
|
|
|
|
|
|
|
32 |
|
33 |
@ app.route('/')
|
34 |
def index():
|