OpenSound committed on
Commit
9c238e8
·
verified ·
1 Parent(s): fc6da2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -73
app.py CHANGED
@@ -26,51 +26,50 @@ import spaces
26
  import nltk
27
  nltk.download('punkt')
28
 
29
- ```
30
  DEMO_PATH = os.getenv("DEMO_PATH", "./demo")
31
  TMP_PATH = os.getenv("TMP_PATH", "./demo/temp")
32
  MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models")
33
  os.makedirs(MODELS_PATH, exist_ok=True)
34
  device = "cuda" if torch.cuda.is_available() else "cpu"
35
 
36
- if not os.path.exists(os.path.join(MODELS_PATH, "wmencodec.th")):
37
- # download wmencodec
38
- url = "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/wmencodec.th"
39
- filename = os.path.join(MODELS_PATH, "wmencodec.th")
40
- response = requests.get(url, stream=True)
41
- response.raise_for_status()
42
- with open(filename, "wb") as file:
43
- for chunk in response.iter_content(chunk_size=8192):
44
- file.write(chunk)
45
- print(f"File downloaded to: {filename}")
46
- else:
47
- print("wmencodec model found")
48
-
49
- if not os.path.exists(os.path.join(MODELS_PATH, "English.pth")):
50
- # download english model
51
- url = "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/English.pth"
52
- filename = os.path.join(MODELS_PATH, "English.pth")
53
- response = requests.get(url, stream=True)
54
- response.raise_for_status()
55
- with open(filename, "wb") as file:
56
- for chunk in response.iter_content(chunk_size=8192):
57
- file.write(chunk)
58
- print(f"File downloaded to: {filename}")
59
- else:
60
- print("english model found")
61
-
62
- if not os.path.exists(os.path.join(MODELS_PATH, "Mandarin.pth")):
63
- # download mandarin model
64
- url = "https://huggingface.co/westbrook/SSR-Speech-Mandarin/resolve/main/Mandarin.pth"
65
- filename = os.path.join(MODELS_PATH, "Mandarin.pth")
66
- response = requests.get(url, stream=True)
67
- response.raise_for_status()
68
- with open(filename, "wb") as file:
69
- for chunk in response.iter_content(chunk_size=8192):
70
- file.write(chunk)
71
- print(f"File downloaded to: {filename}")
72
- else:
73
- print("mandarin model found")
74
 
75
  def get_random_string():
76
  return "".join(str(uuid.uuid4()).split("-"))
@@ -132,40 +131,39 @@ from whisperx import align as align_func
132
  text_tokenizer_en = TextTokenizer(backend="espeak")
133
  text_tokenizer_zh = TextTokenizer(backend="espeak", language='cmn')
134
 
135
- ssrspeech_fn_en = f"{MODELS_PATH}/English.pth"
136
- ckpt_en = torch.load(ssrspeech_fn_en)
137
- model_en = ssr.SSR_Speech(ckpt_en["config"])
138
- model_en.load_state_dict(ckpt_en["model"])
139
- config_en = model_en.args
140
- phn2num_en = ckpt_en["phn2num"]
141
- model_en.to(device)
142
-
143
- ssrspeech_fn_zh = f"{MODELS_PATH}/Mandarin.pth"
144
- ckpt_zh = torch.load(ssrspeech_fn_zh)
145
- model_zh = ssr.SSR_Speech(ckpt_zh["config"])
146
- model_zh.load_state_dict(ckpt_zh["model"])
147
- config_zh = model_zh.args
148
- phn2num_zh = ckpt_zh["phn2num"]
149
- model_zh.to(device)
150
-
151
- encodec_fn = f"{MODELS_PATH}/wmencodec.th"
152
-
153
- ssrspeech_model_en = {
154
- "config": config_en,
155
- "phn2num": phn2num_en,
156
- "model": model_en,
157
- "text_tokenizer": text_tokenizer_en,
158
- "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
159
- }
160
-
161
- ssrspeech_model_zh = {
162
- "config": config_zh,
163
- "phn2num": phn2num_zh,
164
- "model": model_zh,
165
- "text_tokenizer": text_tokenizer_zh,
166
- "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
167
- }
168
- ```
169
 
170
 
171
  def get_transcribe_state(segments):
 
26
  import nltk
27
  nltk.download('punkt')
28
 
 
29
  DEMO_PATH = os.getenv("DEMO_PATH", "./demo")
30
  TMP_PATH = os.getenv("TMP_PATH", "./demo/temp")
31
  MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models")
32
  os.makedirs(MODELS_PATH, exist_ok=True)
33
  device = "cuda" if torch.cuda.is_available() else "cpu"
34
 
35
+ # if not os.path.exists(os.path.join(MODELS_PATH, "wmencodec.th")):
36
+ # # download wmencodec
37
+ # url = "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/wmencodec.th"
38
+ # filename = os.path.join(MODELS_PATH, "wmencodec.th")
39
+ # response = requests.get(url, stream=True)
40
+ # response.raise_for_status()
41
+ # with open(filename, "wb") as file:
42
+ # for chunk in response.iter_content(chunk_size=8192):
43
+ # file.write(chunk)
44
+ # print(f"File downloaded to: {filename}")
45
+ # else:
46
+ # print("wmencodec model found")
47
+
48
+ # if not os.path.exists(os.path.join(MODELS_PATH, "English.pth")):
49
+ # # download english model
50
+ # url = "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/English.pth"
51
+ # filename = os.path.join(MODELS_PATH, "English.pth")
52
+ # response = requests.get(url, stream=True)
53
+ # response.raise_for_status()
54
+ # with open(filename, "wb") as file:
55
+ # for chunk in response.iter_content(chunk_size=8192):
56
+ # file.write(chunk)
57
+ # print(f"File downloaded to: {filename}")
58
+ # else:
59
+ # print("english model found")
60
+
61
+ # if not os.path.exists(os.path.join(MODELS_PATH, "Mandarin.pth")):
62
+ # # download mandarin model
63
+ # url = "https://huggingface.co/westbrook/SSR-Speech-Mandarin/resolve/main/Mandarin.pth"
64
+ # filename = os.path.join(MODELS_PATH, "Mandarin.pth")
65
+ # response = requests.get(url, stream=True)
66
+ # response.raise_for_status()
67
+ # with open(filename, "wb") as file:
68
+ # for chunk in response.iter_content(chunk_size=8192):
69
+ # file.write(chunk)
70
+ # print(f"File downloaded to: {filename}")
71
+ # else:
72
+ # print("mandarin model found")
73
 
74
  def get_random_string():
75
  return "".join(str(uuid.uuid4()).split("-"))
 
131
  text_tokenizer_en = TextTokenizer(backend="espeak")
132
  text_tokenizer_zh = TextTokenizer(backend="espeak", language='cmn')
133
 
134
+ # ssrspeech_fn_en = f"{MODELS_PATH}/English.pth"
135
+ # ckpt_en = torch.load(ssrspeech_fn_en)
136
+ # model_en = ssr.SSR_Speech(ckpt_en["config"])
137
+ # model_en.load_state_dict(ckpt_en["model"])
138
+ # config_en = model_en.args
139
+ # phn2num_en = ckpt_en["phn2num"]
140
+ # model_en.to(device)
141
+
142
+ # ssrspeech_fn_zh = f"{MODELS_PATH}/Mandarin.pth"
143
+ # ckpt_zh = torch.load(ssrspeech_fn_zh)
144
+ # model_zh = ssr.SSR_Speech(ckpt_zh["config"])
145
+ # model_zh.load_state_dict(ckpt_zh["model"])
146
+ # config_zh = model_zh.args
147
+ # phn2num_zh = ckpt_zh["phn2num"]
148
+ # model_zh.to(device)
149
+
150
+ # encodec_fn = f"{MODELS_PATH}/wmencodec.th"
151
+
152
+ # ssrspeech_model_en = {
153
+ # "config": config_en,
154
+ # "phn2num": phn2num_en,
155
+ # "model": model_en,
156
+ # "text_tokenizer": text_tokenizer_en,
157
+ # "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
158
+ # }
159
+
160
+ # ssrspeech_model_zh = {
161
+ # "config": config_zh,
162
+ # "phn2num": phn2num_zh,
163
+ # "model": model_zh,
164
+ # "text_tokenizer": text_tokenizer_zh,
165
+ # "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
166
+ # }
 
167
 
168
 
169
  def get_transcribe_state(segments):