Ved Gupta commited on
Commit
2d8b375
·
1 Parent(s): 3c3c0de

v3 download model added

Browse files
app/main.py CHANGED
@@ -5,6 +5,11 @@ from app.core.errors import error_handler
5
  from fastapi.middleware.cors import CORSMiddleware
6
 
7
  from app.utils import print_routes
 
 
 
 
 
8
 
9
  app = FastAPI(
10
  title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json"
 
5
  from fastapi.middleware.cors import CORSMiddleware
6
 
7
  from app.utils import print_routes
8
+ from app.utils.checks import run_checks
9
+
10
+ if not run_checks():
11
+ raise Exception("Failed to pass all checks")
12
+
13
 
14
  app = FastAPI(
15
  title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json"
app/utils/checks.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from app.utils.constant import model_names, model_urls
3
+ from app.utils.utils import download_file
4
+
5
+
6
+ def run_checks():
7
+ try:
8
+ if not check_models_exist():
9
+ return False
10
+ return True
11
+ except Exception as e:
12
+ print("Error in run_checks: {}".format(str(e)))
13
+ return False
14
+
15
+
16
+ def check_models_exist():
17
+ try:
18
+ for key, value in model_names.items():
19
+ if os.path.exists(os.path.join(os.getcwd(), "models", value)):
20
+ print("Model {} exists".format(key))
21
+ else:
22
+ print("Model {} does not exist".format(key))
23
+ download_model(key)
24
+ return True
25
+ except Exception as e:
26
+ print("Error in check_models_exist: {}".format(str(e)))
27
+ return False
28
+
29
+
30
+ def download_model(model_key: str):
31
+ try:
32
+ print("Downloading model {} from {}".format(model_key, model_urls[model_key]))
33
+ download_file(
34
+ model_urls[model_key],
35
+ os.path.join(os.getcwd(), "models", model_names[model_key]),
36
+ )
37
+ print("Downloaded model {} from {}".format(model_key, model_urls[model_key]))
38
+ except Exception as e:
39
+ print("Error in download_models: {}".format(str(e)))
app/utils/constant.py CHANGED
@@ -3,3 +3,9 @@ model_names = {
3
  "tiny.en.q5": "ggml-model-whisper-tiny.en-q5_1.bin",
4
  "base.en.q5": "ggml-model-whisper-base.en-q5_1.bin",
5
  }
 
 
 
 
 
 
 
3
  "tiny.en.q5": "ggml-model-whisper-tiny.en-q5_1.bin",
4
  "base.en.q5": "ggml-model-whisper-base.en-q5_1.bin",
5
  }
6
+
7
+ model_urls = {
8
+ "tiny.en": "https://firebasestorage.googleapis.com/v0/b/model-innovatorved.appspot.com/o/ggml-model-whisper-base.en-q5_1.bin?alt=media",
9
+ "tiny.en.q5": "https://firebasestorage.googleapis.com/v0/b/model-innovatorved.appspot.com/o/ggml-model-whisper-tiny.en-q5_1.bin?alt=media",
10
+ "base.en.q5": "https://firebasestorage.googleapis.com/v0/b/model-innovatorved.appspot.com/o/ggml-model-whisper-base.en-q5_1.bin?alt=media",
11
+ }
app/utils/utils.py CHANGED
@@ -1,8 +1,12 @@
 
 
1
  import subprocess
2
  import uuid
3
  import logging
4
  import wave
5
  import gdown
 
 
6
 
7
  from .constant import model_names
8
 
@@ -102,10 +106,30 @@ def get_model_name(model: str = None):
102
  return model_names["tiny.en.q5"]
103
 
104
 
105
- def download_from_drive(url , output):
106
  try:
107
  gdown.download(url, output, quiet=False)
108
  return True
109
  except:
110
  print("Error Occured in Downloading model from Gdrive")
111
  return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import urllib
3
  import subprocess
4
  import uuid
5
  import logging
6
  import wave
7
  import gdown
8
+ from tqdm import tqdm
9
+
10
 
11
  from .constant import model_names
12
 
 
106
  return model_names["tiny.en.q5"]
107
 
108
 
109
+ def download_from_drive(url, output):
110
  try:
111
  gdown.download(url, output, quiet=False)
112
  return True
113
  except:
114
  print("Error Occured in Downloading model from Gdrive")
115
  return False
116
+
117
+
118
+ def download_file(url, filepath):
119
+ try:
120
+ filename = os.path.basename(url)
121
+
122
+ with tqdm(
123
+ unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=filename
124
+ ) as progress_bar:
125
+ urllib.request.urlretrieve(
126
+ url,
127
+ filepath,
128
+ reporthook=lambda block_num, block_size, total_size: progress_bar.update(
129
+ block_size
130
+ ),
131
+ )
132
+
133
+ print("File downloaded successfully!")
134
+ except Exception as e:
135
+ print(f"An error occurred: {e}")