mskov commited on
Commit
9e0c17e
·
1 Parent(s): 3161c8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -30,7 +30,7 @@ def classify_toxicity(audio_file, text_input, classify_anxiety):
30
  '''whisper_model = WhisperModel.from_pretrained("openai/whisper-base")
31
  feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
32
  transcription_results = whisper_model.compute(uploaded=audio_file)
33
- '''
34
  audio = whisper.load_audio(audio_file)
35
  mel = whisper.log_mel_spectrogram(audio).to(model.device)
36
  _, probs = model.detect_language(mel)
@@ -38,7 +38,9 @@ def classify_toxicity(audio_file, text_input, classify_anxiety):
38
  result = whisper.decode(model, mel, options)
39
  # Extract the transcribed text
40
  # transcribed_text = transcription_results["transcription"]
41
- transcribed_text = resut.text
 
 
42
 
43
  #### Emotion classification ####
44
  emotion_classifier = foreign_class(source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP", pymodule_file="custom_interface.py", classname="CustomEncoderWav2vec2Classifier")
 
30
  '''whisper_model = WhisperModel.from_pretrained("openai/whisper-base")
31
  feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
32
  transcription_results = whisper_model.compute(uploaded=audio_file)
33
+
34
  audio = whisper.load_audio(audio_file)
35
  mel = whisper.log_mel_spectrogram(audio).to(model.device)
36
  _, probs = model.detect_language(mel)
 
38
  result = whisper.decode(model, mel, options)
39
  # Extract the transcribed text
40
  # transcribed_text = transcription_results["transcription"]
41
+ '''
42
+ model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small", use_auth_token=huggingface_token)
43
+ transcribed_text = model.transcribe(audio_file)
44
 
45
  #### Emotion classification ####
46
  emotion_classifier = foreign_class(source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP", pymodule_file="custom_interface.py", classname="CustomEncoderWav2vec2Classifier")