File size: 4,229 Bytes
babca6f
 
cbe4d4c
 
c8e54ed
1ae8e53
df85058
585a1e8
53eb88c
 
 
 
 
 
 
28ff844
df85058
 
 
 
 
 
 
 
2cadcf2
df85058
c8e54ed
53eb88c
c8e54ed
bbd3701
babca6f
 
 
9e0c17e
babca6f
 
 
 
 
f10b2fa
babca6f
9e0c17e
2cadcf2
 
 
f5e59d1
 
 
 
 
f10b2fa
6bfef5d
c8e54ed
73d041b
 
e95ab8a
b65fb2a
1ff03d5
c8e54ed
 
 
1ff03d5
53eb88c
73d041b
53eb88c
 
 
73d041b
53eb88c
 
 
73d041b
 
53eb88c
 
73d041b
 
df85058
8dbe0c3
df85058
 
2724e1c
c8e54ed
33b1b5b
53eb88c
 
33b1b5b
ca7ae8f
 
335e90e
 
33b1b5b
53eb88c
30dbd25
c8e54ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
os.system("pip install git+https://github.com/openai/whisper.git")
import evaluate
from evaluate.utils import launch_gradio_widget
import gradio as gr
import torch
from speechbrain.pretrained.interfaces import foreign_class
from transformers import AutoModelForSequenceClassification, pipeline, RobertaForSequenceClassification, RobertaTokenizer, AutoTokenizer
# pull in emotion detection
# --- Add element for specification
# pull in text classification
# --- Add custom labels
# --- Associate labels with radio elements
# add logic to initiate mock notification when detected
# pull in misophonia-specific model

# Display names for the short emotion labels produced by the SpeechBrain
# IEMOCAP emotion classifier used in classify_toxicity ('neu', 'ang', ...).
emotion_dict = {
    'sad': 'Sad',
    'hap': 'Happy',
    'ang': 'Anger',
    'neu': 'Neutral',
}

# Speech-to-text pipeline shared by every request. The model is pinned
# explicitly: calling pipeline() without `model=` silently falls back to a
# library default that can change between transformers releases. Whisper is
# the ASR model this app is built around (see the whisper install above).
# chunk_length_s lets Whisper transcribe clips longer than its 30 s window.
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base",
    chunk_length_s=30,
)

import functools


@functools.lru_cache(maxsize=None)
def _get_toxicity_module():
    """Load the `evaluate` toxicity measurement once and reuse it across requests."""
    return evaluate.load("toxicity", "facebook/roberta-hate-speech-dynabench-r4-target")


@functools.lru_cache(maxsize=None)
def _get_topic_classifier():
    """Load the zero-shot classification pipeline once and reuse it across requests."""
    return pipeline("zero-shot-classification", model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli")


@functools.lru_cache(maxsize=None)
def _get_emotion_classifier():
    """Load the SpeechBrain IEMOCAP emotion classifier once and reuse it across requests."""
    return foreign_class(
        source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
        pymodule_file="custom_interface.py",
        classname="CustomEncoderWav2vec2Classifier",
    )


def classify_toxicity(audio_file, text_input, classify_anxiety):
    """Score uploaded audio or entered text for toxicity, category and emotion.

    Args:
        audio_file: Path to the uploaded audio file (the Gradio ``Audio``
            input uses ``type="filepath"``), or ``None`` if nothing was uploaded.
        text_input: Free text; used only when ``audio_file`` is ``None``.
        classify_anxiety: Candidate label(s) for zero-shot classification
            (the value selected in the Gradio ``Radio``); may be ``None``.

    Returns:
        A 4-tuple ``(toxicity_score, classification_output, emotion, text)``:
        the toxicity score of the text, the zero-shot pipeline output
        (``None`` when no label was selected), the display name of the
        detected emotion (``None`` for text-only input), and the transcribed
        or entered text. When there is no text to analyse, the first two
        items are ``None`` and the text is ``""``.
    """
    emotion = None
    if audio_file is not None:
        # Transcribe with the module-level ASR pipeline `pipe`.
        transcribed_text = pipe(audio_file)["text"]
        # audio_file is already a path string, not a file object, so it is
        # passed to classify_file() directly (`audio_file.name` would raise).
        _, _, _, text_lab = _get_emotion_classifier().classify_file(audio_file)
        # Fall back to the raw label if the model emits one not in emotion_dict.
        emotion = emotion_dict.get(text_lab[0], text_lab[0])
    else:
        transcribed_text = text_input

    if not transcribed_text:
        # Nothing to score: empty textbox with no audio, or an empty transcript.
        return None, None, emotion, ""

    #### Toxicity classifier ####
    toxicity_results = _get_toxicity_module().compute(predictions=[transcribed_text])
    toxicity_score = toxicity_results["toxicity"][0]
    print(toxicity_score)

    #### Zero-shot category classification ####
    classification_output = None
    if classify_anxiety:
        # The zero-shot pipeline needs at least one candidate label, so it is
        # only run when the user actually picked a category.
        classification_output = _get_topic_classifier()(
            transcribed_text, classify_anxiety, multi_label=False
        )
    print(classification_output)

    return toxicity_score, classification_output, emotion, transcribed_text
 
with gr.Blocks() as iface:
    with gr.Column():
        classify = gr.Radio(
            ["racial identity hate", "LGBTQ+ hate", "sexually explicit", "misophonia"],
            label="Category to check",
        )
    with gr.Column():
        aud_input = gr.Audio(source="upload", type="filepath", label="Upload Audio File")
        text = gr.Textbox(label="Enter Text", placeholder="Enter text here...")
        # The button caption is its `value`, not a `label`.
        submit_btn = gr.Button("Run")
    with gr.Column():
        # One output component per value returned by classify_toxicity:
        # (toxicity score, zero-shot output dict, emotion name, text).
        toxicity_out = gr.Number(label="Toxicity score")
        category_out = gr.JSON(label="Category classification")
        emotion_out = gr.Textbox(label="Detected emotion")
        out_text = gr.Textbox(label="Transcribed / entered text")
    submit_btn.click(
        fn=classify_toxicity,
        inputs=[aud_input, text, classify],
        outputs=[toxicity_out, category_out, emotion_out, out_text],
    )

iface.launch()