Spaces:
Runtime error
Runtime error
Update goai_stt.py
Browse files- goai_stt.py +13 -13
goai_stt.py
CHANGED
@@ -5,18 +5,14 @@ from transformers import set_seed, Wav2Vec2ForCTC, AutoProcessor
|
|
5 |
|
6 |
device = 0 if torch.cuda.is_available() else "cpu"
|
7 |
|
8 |
-
|
9 |
def goai_stt(fichier):
|
10 |
"""
|
11 |
Transcrire un fichier audio donné.
|
12 |
|
13 |
Paramètres
|
14 |
----------
|
15 |
-
fichier: str
|
16 |
-
Le chemin d'accès au fichier audio.
|
17 |
-
|
18 |
-
device: str
|
19 |
-
GPU ou CPU
|
20 |
|
21 |
Return
|
22 |
----------
|
@@ -24,9 +20,8 @@ def goai_stt(fichier):
|
|
24 |
Le texte transcrit.
|
25 |
"""
|
26 |
|
27 |
-
|
28 |
### assurer reproducibilité
|
29 |
-
set_seed(2024)
|
30 |
|
31 |
start_time = time.time()
|
32 |
|
@@ -34,11 +29,16 @@ def goai_stt(fichier):
|
|
34 |
model_id = "anyantudre/wav2vec2-large-mms-1b-mos-V1"
|
35 |
|
36 |
processor = AutoProcessor.from_pretrained(model_id)
|
37 |
-
model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang="mos", ignore_mismatched_sizes=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
-
|
40 |
-
signal, sampling_rate = librosa.load(fichier, sr=16000)
|
41 |
-
inputs = processor(signal, sampling_rate=16_000, return_tensors="pt", padding=True)
|
42 |
|
43 |
### faire l'inference
|
44 |
with torch.no_grad():
|
@@ -48,4 +48,4 @@ def goai_stt(fichier):
|
|
48 |
transcription = processor.decode(pred_ids)
|
49 |
|
50 |
print("Temps écoulé: ", int(time.time() - start_time), " secondes")
|
51 |
-
return transcription
|
|
|
5 |
|
6 |
device = 0 if torch.cuda.is_available() else "cpu"
|
7 |
|
|
|
8 |
def goai_stt(fichier):
|
9 |
"""
|
10 |
Transcrire un fichier audio donné.
|
11 |
|
12 |
Paramètres
|
13 |
----------
|
14 |
+
fichier: str | np.ndarray
|
15 |
+
Le chemin d'accès au fichier audio ou le tableau numpy.
|
|
|
|
|
|
|
16 |
|
17 |
Return
|
18 |
----------
|
|
|
20 |
Le texte transcrit.
|
21 |
"""
|
22 |
|
|
|
23 |
### assurer reproducibilité
|
24 |
+
set_seed(2024)
|
25 |
|
26 |
start_time = time.time()
|
27 |
|
|
|
29 |
model_id = "anyantudre/wav2vec2-large-mms-1b-mos-V1"
|
30 |
|
31 |
processor = AutoProcessor.from_pretrained(model_id)
|
32 |
+
model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang="mos", ignore_mismatched_sizes=True).to(device)
|
33 |
+
|
34 |
+
if isinstance(fichier, str):
|
35 |
+
### preprocessing de l'audio à partir d'un fichier
|
36 |
+
signal, sampling_rate = librosa.load(fichier, sr=16000)
|
37 |
+
else:
|
38 |
+
### preprocessing de l'audio à partir d'un tableau numpy
|
39 |
+
signal, sampling_rate = fichier
|
40 |
|
41 |
+
inputs = processor(signal, sampling_rate=16_000, return_tensors="pt", padding=True).to(device)
|
|
|
|
|
42 |
|
43 |
### faire l'inference
|
44 |
with torch.no_grad():
|
|
|
48 |
transcription = processor.decode(pred_ids)
|
49 |
|
50 |
print("Temps écoulé: ", int(time.time() - start_time), " secondes")
|
51 |
+
return transcription
|