anyantudre commited on
Commit
5777e51
·
verified ·
1 Parent(s): 7a38a58

Update goai_stt.py

Browse files
Files changed (1) hide show
  1. goai_stt.py +13 -13
goai_stt.py CHANGED
@@ -5,18 +5,14 @@ from transformers import set_seed, Wav2Vec2ForCTC, AutoProcessor
5
 
6
  device = 0 if torch.cuda.is_available() else "cpu"
7
 
8
-
9
  def goai_stt(fichier):
10
  """
11
  Transcrire un fichier audio donné.
12
 
13
  Paramètres
14
  ----------
15
- fichier: str
16
- Le chemin d'accès au fichier audio.
17
-
18
- device: str
19
- GPU ou CPU
20
 
21
  Return
22
  ----------
@@ -24,9 +20,8 @@ def goai_stt(fichier):
24
  Le texte transcrit.
25
  """
26
 
27
-
28
  ### assurer reproducibilité
29
- set_seed(2024)
30
 
31
  start_time = time.time()
32
 
@@ -34,11 +29,16 @@ def goai_stt(fichier):
34
  model_id = "anyantudre/wav2vec2-large-mms-1b-mos-V1"
35
 
36
  processor = AutoProcessor.from_pretrained(model_id)
37
- model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang="mos", ignore_mismatched_sizes=True)
 
 
 
 
 
 
 
38
 
39
- ### preprocessing de l'audio
40
- signal, sampling_rate = librosa.load(fichier, sr=16000)
41
- inputs = processor(signal, sampling_rate=16_000, return_tensors="pt", padding=True)
42
 
43
  ### faire l'inference
44
  with torch.no_grad():
@@ -48,4 +48,4 @@ def goai_stt(fichier):
48
  transcription = processor.decode(pred_ids)
49
 
50
  print("Temps écoulé: ", int(time.time() - start_time), " secondes")
51
- return transcription
 
5
 
6
  device = 0 if torch.cuda.is_available() else "cpu"
7
 
 
8
  def goai_stt(fichier):
9
  """
10
  Transcrire un fichier audio donné.
11
 
12
  Paramètres
13
  ----------
14
+ fichier: str | np.ndarray
15
+ Le chemin d'accès au fichier audio ou le tableau numpy.
 
 
 
16
 
17
  Return
18
  ----------
 
20
  Le texte transcrit.
21
  """
22
 
 
23
  ### assurer reproducibilité
24
+ set_seed(2024)
25
 
26
  start_time = time.time()
27
 
 
29
  model_id = "anyantudre/wav2vec2-large-mms-1b-mos-V1"
30
 
31
  processor = AutoProcessor.from_pretrained(model_id)
32
+ model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang="mos", ignore_mismatched_sizes=True).to(device)
33
+
34
+ if isinstance(fichier, str):
35
+ ### preprocessing de l'audio à partir d'un fichier
36
+ signal, sampling_rate = librosa.load(fichier, sr=16000)
37
+ else:
38
+ ### preprocessing de l'audio à partir d'un tableau numpy
39
+ signal, sampling_rate = fichier
40
 
41
+ inputs = processor(signal, sampling_rate=16_000, return_tensors="pt", padding=True).to(device)
 
 
42
 
43
  ### faire l'inference
44
  with torch.no_grad():
 
48
  transcription = processor.decode(pred_ids)
49
 
50
  print("Temps écoulé: ", int(time.time() - start_time), " secondes")
51
+ return transcription