File size: 5,967 Bytes
babca6f
 
dc9cf4c
cbe4d4c
 
c8e54ed
1ae8e53
ff14337
87e9ad0
ff14337
df85058
ff14337
 
 
53eb88c
 
 
 
 
 
 
28ff844
ff14337
 
df85058
a94b06f
df85058
 
 
 
 
 
61fa7d4
 
 
34bf2a6
61fa7d4
 
 
 
4b9eea9
df85058
c8e54ed
53eb88c
c8e54ed
bbd3701
2cadcf2
f5e59d1
 
 
9d36990
f5e59d1
f10b2fa
6bfef5d
395d676
 
 
 
 
73d041b
395d676
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6615174
ff14337
7a481f6
187b547
bb7f792
187b547
bb7f792
 
 
7a481f6
187b547
ff14337
 
 
 
6615174
ff14337
789fd51
ff14337
 
be06195
ff14337
 
 
 
 
 
 
 
 
 
395d676
ff14337
33b1b5b
53eb88c
187b547
33b1b5b
ca7ae8f
 
335e90e
 
33b1b5b
187b547
30dbd25
c8e54ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
os.system("pip install git+https://github.com/openai/whisper.git")
import whisper
import evaluate
from evaluate.utils import launch_gradio_widget
import gradio as gr
import torch
import classify
from whisper.model import Whisper
from whisper.tokenizer import get_tokenizer
from speechbrain.pretrained.interfaces import foreign_class
from transformers import AutoModelForSequenceClassification, pipeline, WhisperTokenizer, RobertaForSequenceClassification, RobertaTokenizer, AutoTokenizer


# pull in emotion detection
# --- Add element for specification
# pull in text classification
# --- Add custom labels
# --- Associate labels with radio elements
# add logic to trigger a mock notification when a selected category is detected
# pull in misophonia-specific model

# Registry for loaded models, keyed by a short name (intended to avoid reloading
# heavy checkpoints on every request).
model_cache = {}

# Building prediction function for gradio
# Maps the short labels returned by the SpeechBrain IEMOCAP emotion classifier
# (text_lab[0]) to the display names shown to the user.
emo_dict = {
    'sad': 'Sad', 
    'hap': 'Happy',
    'ang': 'Anger',
    'neu': 'Neutral'
}

# Labels per detection category; the keys are the categories offered in the UI.
# The first three lists are candidate labels for zero-shot text classification;
# the "misophonia" list holds the sound class names scored directly from audio.
# Static classes for now, but it would be best to have the user select from
# multiple, and to enter their own.
class_options = {
    "racism": ["racism", "hate speech", "bigotry", "racially targeted", "racially diminutive", "racial slur", "ethnic slur", "ethnic hate", "pro-white nationalism"],
    "LGBTQ+ hate": ["gay slur", "trans slur", "homophobic slur", "transphobia", "anti-LGBTQ+", "hate speech"],
    "sexually explicit": ["sexually explicit", "sexually coercive", "sexual exploitation", "vulgar", "raunchy", "sexually demeaning", "sexual violence", "victim blaming"],
    "misophonia": ["chewing", "breathing", "mouthsounds", "popping", "sneezing", "yawning", "smacking", "sniffling", "panting"]
}

# Hugging Face speech-to-text pipeline (Whisper large), created once when this
# module is executed, used to transcribe uploaded audio before text scoring.
pipe = pipeline(task="automatic-speech-recognition", model="openai/whisper-large")

# Create a Gradio interface with audio file and text inputs
def _get_cached_model(key, loader):
    """Return ``model_cache[key]``, calling ``loader()`` to populate it on first use.

    The checkpoints used here are large, so each one is created once per process
    instead of once per request.
    """
    if key not in model_cache:
        model_cache[key] = loader()
    return model_cache[key]


def _detect_emotion(audio_file):
    """Return the display name of the emotion detected in *audio_file*.

    Returns None when no audio was supplied (text-only input): the SpeechBrain
    classifier only scores audio files.
    """
    if audio_file is None:
        return None
    emotion_classifier = _get_cached_model(
        "emotion",
        lambda: foreign_class(
            source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
            pymodule_file="custom_interface.py",
            classname="CustomEncoderWav2vec2Classifier",
        ),
    )
    # classify_file yields (out_prob, score, index, text_lab); only the label is used.
    _, _, _, text_lab = emotion_classifier.classify_file(audio_file)
    return emo_dict[text_lab[0]]


def _score_text(transcribed_text, classify_anxiety):
    """Return ``(toxicity_score, classification_output)`` for *transcribed_text*.

    ``classification_output`` is the multi-label zero-shot result over the
    candidate labels of *classify_anxiety*, or None when that category has no
    labels (e.g. nothing selected), rather than calling the pipeline with an
    empty label list.
    """
    toxicity_module = _get_cached_model(
        "toxicity",
        lambda: evaluate.load("toxicity", "facebook/roberta-hate-speech-dynabench-r4-target"),
    )
    toxicity_results = toxicity_module.compute(predictions=[transcribed_text])
    toxicity_score = toxicity_results["toxicity"][0]
    print(toxicity_score)

    candidate_labels = class_options.get(classify_anxiety, [])
    print(classify_anxiety, class_options)
    if not candidate_labels:
        return toxicity_score, None

    text_classifier = _get_cached_model(
        "zero-shot",
        lambda: pipeline("zero-shot-classification", model="facebook/bart-large-mnli"),
    )
    classification_output = text_classifier(transcribed_text, candidate_labels, multi_label=True)
    print(classification_output)
    return toxicity_score, classification_output


def _score_misophonia(audio_file, classify_anxiety):
    """Return ``{class_name: probability}`` for the misophonia sound classes.

    Uses the local ``classify`` helpers with an openai-whisper "large" model:
    per-class average log-probs given the audio, minus the baseline from
    ``calculate_internal_lm_average_logprobs``, softmaxed over the classes.

    Raises:
        ValueError: if no audio file was supplied.
    """
    if audio_file is None:
        raise ValueError("Misophonia detection requires an uploaded audio file.")

    model = _get_cached_model("whisper-large", lambda: whisper.load_model("large"))
    tokenizer = _get_cached_model(
        "whisper-tokenizer",
        lambda: WhisperTokenizer.from_pretrained("openai/whisper-large"),
    )
    # Use the configured names as-is; a join/split round-trip on "," would add
    # a spurious empty-string class.
    class_names = list(class_options.get(classify_anxiety, []))
    print("class names ", class_names, "classify_anxiety ", classify_anxiety)

    internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
        model=model,
        class_names=class_names,
        tokenizer=tokenizer,
    )
    audio_features = classify.calculate_audio_features(audio_file, model)
    average_logprobs = classify.calculate_average_logprobs(
        model=model,
        audio_features=audio_features,
        class_names=class_names,
        tokenizer=tokenizer,
    )
    average_logprobs -= internal_lm_average_logprobs
    scores = average_logprobs.softmax(-1).tolist()
    return dict(zip(class_names, scores))


def classify_toxicity(audio_file, text_input, classify_anxiety):
    """Gradio callback: analyse an uploaded audio file or typed text.

    Args:
        audio_file: Filepath of the uploaded audio, or None.
        text_input: Typed text; only used when no audio was uploaded.
        classify_anxiety: Selected category, a key of ``class_options`` (or None).

    Returns:
        For "misophonia": a ``{sound_class: probability}`` dict computed from audio.
        Otherwise: ``(toxicity_score, classification_output, emotion,
        transcribed_text)``; ``emotion`` is None for text-only input and
        ``classification_output`` is None when the category has no labels.

    Raises:
        ValueError: for "misophonia" when no audio file was supplied.
    """
    if classify_anxiety == "misophonia":
        # Audio-only path: transcription and emotion are not used here.
        return _score_misophonia(audio_file, classify_anxiety)

    # Transcribe the audio file using Whisper ASR when given; else score the text.
    if audio_file is not None:
        transcribed_text = pipe(audio_file)["text"]
    else:
        transcribed_text = text_input or ""

    toxicity_score, classification_output = _score_text(transcribed_text, classify_anxiety)
    emotion = _detect_emotion(audio_file)
    return toxicity_score, classification_output, emotion, transcribed_text
     
# Gradio UI: choose a category, upload audio or type text, then press Run.
with gr.Blocks() as iface:
    with gr.Column():
        # Choices mirror the class_options keys (same order); a default value
        # keeps classify_toxicity from receiving None when nothing is picked.
        anxiety_class = gr.Radio(list(class_options), value="racism", label="Content to detect")
    with gr.Column():
        aud_input = gr.Audio(source="upload", type="filepath", label="Upload Audio File")
        text = gr.Textbox(label="Enter Text", placeholder="Enter text here...")
        submit_btn = gr.Button("Run")
    with gr.Column():
        # classify_toxicity returns a tuple or a dict depending on the category;
        # either is shown in this single Textbox.
        out_text = gr.Textbox(label="Result")
    submit_btn.click(fn=classify_toxicity, inputs=[aud_input, text, anxiety_class], outputs=out_text)

iface.launch()