# app.py — last commit by mskov: "Update app.py" (1ae8e53, 2.55 kB)
from functools import lru_cache

import evaluate
import gradio as gr
import torch
from evaluate.utils import launch_gradio_widget
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    RobertaForSequenceClassification,
    RobertaTokenizer,
    pipeline,
)
# pull in emotion detection
# --- Add element for specification
# pull in text classification
# --- Add custom labels
# --- Associate labels with radio elements
# add logic to trigger a mock notification when a category is detected
# pull in misophonia-specific model
# Create a Gradio interface with audio file and text inputs
@lru_cache(maxsize=None)
def _load_whisper():
    """Load the evaluate "whisper" module once and reuse it across requests."""
    # NOTE(review): "whisper" is presumably a custom/community evaluate module — confirm.
    return evaluate.load("whisper")


@lru_cache(maxsize=None)
def _load_toxicity():
    """Load the toxicity measurement (RoBERTa hate-speech model) once and reuse it."""
    # Alternative model tried earlier:
    # evaluate.load("toxicity", "DaNLP/da-electra-hatespeech-detection", module_type="measurement")
    return evaluate.load("toxicity", "facebook/roberta-hate-speech-dynabench-r4-target")


@lru_cache(maxsize=None)
def _load_zero_shot_classifier():
    """Build the zero-shot classification pipeline once, on the first GPU if available."""
    # transformers pipelines take an int device: 0 = first CUDA device, -1 = CPU.
    device = 0 if torch.cuda.is_available() else -1
    return pipeline(
        "zero-shot-classification",
        model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
        device=device,
    )


def classify_toxicity(audio_file, text_input, classify_anxiety):
    """Score the toxicity of an uploaded audio clip or typed text.

    If ``audio_file`` is given, it is transcribed with the evaluate "whisper"
    module and the transcription is scored. Otherwise ``text_input`` is scored.
    When a category is selected, the text is also run through a zero-shot
    classifier with that category as the candidate label. That result is
    currently only printed.

    Args:
        audio_file: Audio input, or None. Presumably a file path, since the UI
            uses ``type="filepath"``.
        text_input: Text to score when no audio is provided.
        classify_anxiety: Candidate label(s) for zero-shot classification.
            Falsy (e.g. None when no radio option is picked) skips that step.

    Returns:
        A ``(toxicity_score, transcribed_text)`` tuple. The score is the first
        entry of the toxicity measurement's ``"toxicity"`` list.
    """
    if audio_file is not None:
        transcription_results = _load_whisper().compute(uploaded=audio_file)
        transcribed_text = transcription_results["transcription"]
    else:
        transcribed_text = text_input

    toxicity_results = _load_toxicity().compute(predictions=[transcribed_text])
    toxicity_score = toxicity_results["toxicity"][0]
    print(toxicity_score)

    # The zero-shot pipeline rejects an empty label set, and the Radio has no
    # default selection, so only classify when a category was chosen.
    if classify_anxiety:
        classification_output = _load_zero_shot_classifier()(
            transcribed_text, classify_anxiety, multi_label=False
        )
        # TODO: surface this result in the UI / drive the mock notification.
        print(classification_output)

    return toxicity_score, transcribed_text
# Gradio UI: choose a category, provide audio or text, and view the toxicity
# score together with the (transcribed) text.
with gr.Blocks() as iface:
    with gr.Column():
        # The selected option is passed to classify_toxicity as its third argument.
        classify = gr.Radio(
            ["racial identity hate", "LGBTQ+ hate", "sexually explicit", "misophonia"],
            label="Category to detect",
        )
    with gr.Column():
        aud_input = gr.Audio(source="upload", type="filepath", label="Upload Audio File")
        text = gr.Textbox(label="Enter Text", placeholder="Enter text here...")
        # A Button's caption is its value, not a `label`.
        submit_btn = gr.Button("Run")
    with gr.Column():
        out_score = gr.Number(label="Toxicity score")
        out_text = gr.Textbox(label="Transcribed text")
    # classify_toxicity returns (toxicity_score, transcribed_text): one output per value.
    submit_btn.click(
        fn=classify_toxicity,
        inputs=[aud_input, text, classify],
        outputs=[out_score, out_text],
    )
iface.launch()