mskov committed on
Commit
ff14337
·
1 Parent(s): 21a90b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -2
app.py CHANGED
@@ -4,8 +4,12 @@ import evaluate
4
  from evaluate.utils import launch_gradio_widget
5
  import gradio as gr
6
  import torch
 
 
7
  from speechbrain.pretrained.interfaces import foreign_class
8
- from transformers import AutoModelForSequenceClassification, pipeline, RobertaForSequenceClassification, RobertaTokenizer, AutoTokenizer
 
 
9
  # pull in emotion detection
10
  # --- Add element for specification
11
  # pull in text classification
@@ -14,6 +18,8 @@ from transformers import AutoModelForSequenceClassification, pipeline, RobertaFo
14
  # add logic to initiate mock notification when detected
15
  # pull in misophonia-specific model
16
 
 
 
17
  # Building prediction function for gradio
18
  emo_dict = {
19
  'sad': 'Sad',
@@ -76,8 +82,38 @@ def classify_toxicity(audio_file, text_input, classify_anxiety):
76
  return toxicity_score, classification_output, emo_dict[text_lab[0]], transcribed_text
77
  # return f"Toxicity Score ({available_models[selected_model]}): {toxicity_score:.4f}"
78
  else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  return classify_anxiety
80
-
81
  with gr.Blocks() as iface:
82
  with gr.Column():
83
  classify = gr.Radio(["racism", "LGBTQ+ hate", "sexually explicit", "misophonia"])
 
4
  from evaluate.utils import launch_gradio_widget
5
  import gradio as gr
6
  import torch
7
+ import classify
8
+ from whisper.tokenizer import get_tokenizer
9
  from speechbrain.pretrained.interfaces import foreign_class
10
+ from transformers import AutoModelForSequenceClassification, pipeline, WhisperTokenizer, RobertaForSequenceClassification, RobertaTokenizer, AutoTokenizer
11
+
12
+
13
  # pull in emotion detection
14
  # --- Add element for specification
15
  # pull in text classification
 
18
  # add logic to initiate mock notification when detected
19
  # pull in misophonia-specific model
20
 
21
+ model_cache = {}
22
+
23
  # Building prediction function for gradio
24
  emo_dict = {
25
  'sad': 'Sad',
 
82
  return toxicity_score, classification_output, emo_dict[text_lab[0]], transcribed_text
83
  # return f"Toxicity Score ({available_models[selected_model]}): {toxicity_score:.4f}"
84
  else:
85
+ # model = model_cache[model_name]
86
+ class_names = classify_anxiety.split(",")
87
+ # tokenizer = get_tokenizer(multilingual=".en" not in model_name)
88
+ tokenizer= WhisperTokenizer.from_pretrained("openai/whisper-large")
89
+ model = "large"
90
+
91
+ if model_name not in model_cache:
92
+ model = whisper.load_model(model_name)
93
+ model_cache[model_name] = model
94
+ else:
95
+ # model = model_cache[model_name]
96
+ class_names = classify_anxiety.split(",")
97
+
98
+ internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
99
+ model=model,
100
+ class_names=class_names,
101
+ class_names=classify_anxiety
102
+ tokenizer=tokenizer,
103
+ )
104
+ audio_features = classify.calculate_audio_features(audio_path, model)
105
+ average_logprobs = classify.calculate_average_logprobs(
106
+ model=model,
107
+ audio_features=audio_features,
108
+ class_names=class_names,
109
+ tokenizer=tokenizer,
110
+ )
111
+ average_logprobs -= internal_lm_average_logprobs
112
+ scores = average_logprobs.softmax(-1).tolist()
113
+ return {class_name: score for class_name, score in zip(class_names, scores)}
114
+
115
  return classify_anxiety
116
+
117
  with gr.Blocks() as iface:
118
  with gr.Column():
119
  classify = gr.Radio(["racism", "LGBTQ+ hate", "sexually explicit", "misophonia"])