nevreal commited on
Commit
97f1ad3
·
verified ·
1 Parent(s): aa226e1

Create easyfuncs.py

Browse files
Files changed (1) hide show
  1. easyfuncs.py +112 -0
easyfuncs.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, subprocess
2
+ import gradio as gr
3
+ import shutil
4
+ from mega import Mega
5
+
6
+ import pandas as pd
7
+ import os
8
+
9
+ # Class to handle caching model urls from a spreadsheet
10
+ class CachedModels:
11
+ def __init__(self):
12
+ csv_url = "https://docs.google.com/spreadsheets/d/1tAUaQrEHYgRsm1Lvrnj14HFHDwJWl0Bd9x0QePewNco/export?format=csv&gid=1977693859"
13
+ if os.path.exists("spreadsheet.csv"):
14
+ self.cached_data = pd.read_csv("spreadsheet.csv")
15
+ else:
16
+ self.cached_data = pd.read_csv(csv_url)
17
+ self.cached_data.to_csv("spreadsheet.csv", index=False)
18
+ # Cache model urls
19
+ self.models = {}
20
+ for _, row in self.cached_data.iterrows():
21
+ filename = row['Filename']
22
+ url = None
23
+ for value in row.values:
24
+ if isinstance(value, str) and "huggingface" in value:
25
+ url = value
26
+ break
27
+ if url:
28
+ self.models[filename] = url
29
+ # Get cached model urls
30
+ def get_models(self):
31
+ return self.models
32
+
33
+ def show(path,ext,on_error=None):
34
+ try:
35
+ return list(filter(lambda x: x.endswith(ext), os.listdir(path)))
36
+ except:
37
+ return on_error
38
+
39
+ def run_subprocess(command):
40
+ try:
41
+ subprocess.run(command, check=True)
42
+ return True, None
43
+ except Exception as e:
44
+ return False, e
45
+
46
+ def download_from_url(url=None, model=None):
47
+ if not url:
48
+ try:
49
+ url = model[f'{model}']
50
+ except:
51
+ gr.Warning("Failed")
52
+ return ''
53
+ if model == '':
54
+ try:
55
+ model = url.split('/')[-1].split('?')[0]
56
+ except:
57
+ gr.Warning('Please name the model')
58
+ return
59
+ model = model.replace('.pth', '').replace('.index', '').replace('.zip', '')
60
+ url = url.replace('/blob/main/', '/resolve/main/').strip()
61
+
62
+ for directory in ["downloads", "unzips","zip"]:
63
+ #shutil.rmtree(directory, ignore_errors=True)
64
+ os.makedirs(directory, exist_ok=True)
65
+
66
+ try:
67
+ if url.endswith('.pth'):
68
+ subprocess.run(["wget", url, "-O", f'assets/weights/{model}.pth'])
69
+ elif url.endswith('.index'):
70
+ os.makedirs(f'logs/{model}', exist_ok=True)
71
+ subprocess.run(["wget", url, "-O", f'logs/{model}/added_{model}.index'])
72
+ elif url.endswith('.zip'):
73
+ subprocess.run(["wget", url, "-O", f'downloads/{model}.zip'])
74
+ else:
75
+ if "drive.google.com" in url:
76
+ url = url.split('/')[0]
77
+ subprocess.run(["gdown", url, "--fuzzy", "-O", f'downloads/{model}'])
78
+ elif "mega.nz" in url:
79
+ Mega().download_url(url, 'downloads')
80
+ else:
81
+ subprocess.run(["wget", url, "-O", f'downloads/{model}'])
82
+
83
+ downloaded_file = next((f for f in os.listdir("downloads")), None)
84
+ if downloaded_file:
85
+ if downloaded_file.endswith(".zip"):
86
+ shutil.unpack_archive(f'downloads/{downloaded_file}', "unzips", 'zip')
87
+ for root, _, files in os.walk('unzips'):
88
+ for file in files:
89
+ file_path = os.path.join(root, file)
90
+ if file.endswith(".index"):
91
+ os.makedirs(f'logs/{model}', exist_ok=True)
92
+ shutil.copy2(file_path, f'logs/{model}')
93
+ elif file.endswith(".pth") and "G_" not in file and "D_" not in file:
94
+ shutil.copy(file_path, f'assets/weights/{model}.pth')
95
+ elif downloaded_file.endswith(".pth"):
96
+ shutil.copy(f'downloads/{downloaded_file}', f'assets/weights/{model}.pth')
97
+ elif downloaded_file.endswith(".index"):
98
+ os.makedirs(f'logs/{model}', exist_ok=True)
99
+ shutil.copy(f'downloads/{downloaded_file}', f'logs/{model}/added_{model}.index')
100
+ else:
101
+ gr.Warning("Failed to download file")
102
+ return 'Failed'
103
+
104
+ gr.Info("Done")
105
+ except Exception as e:
106
+ gr.Warning(f"There's been an error: {str(e)}")
107
+ finally:
108
+ shutil.rmtree("downloads", ignore_errors=True)
109
+ shutil.rmtree("unzips", ignore_errors=True)
110
+ shutil.rmtree("zip", ignore_errors=True)
111
+ return 'Done'
112
+