Spaces:
Runtime error
Runtime error
File size: 2,553 Bytes
cbe4d4c c8e54ed 1ae8e53 c8e54ed 53eb88c 28ff844 c8e54ed 53eb88c c8e54ed bbd3701 f10b2fa 6bfef5d c8e54ed e95ab8a b65fb2a 1ff03d5 c8e54ed 1ff03d5 53eb88c 8cf8567 2724e1c c8e54ed 33b1b5b 53eb88c 33b1b5b ca7ae8f 335e90e 33b1b5b 53eb88c 30dbd25 c8e54ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import functools

import evaluate
from evaluate.utils import launch_gradio_widget
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, pipeline, RobertaForSequenceClassification, RobertaTokenizer, AutoTokenizer
# Development notes:
# - pull in emotion detection
#   - add a UI element for specifying what to detect
# - pull in text classification
#   - add custom labels
#   - associate the labels with the radio elements
# - add logic to trigger a mock notification when a category is detected
# - pull in a misophonia-specific model
# Create a Gradio interface with audio-file and text inputs.
@functools.lru_cache(maxsize=None)
def _load_whisper_module():
    """Load the ``evaluate`` "whisper" module once and reuse it across calls."""
    # NOTE(review): "whisper" is not a canonical evaluate module and the
    # ``compute(uploaded=...)`` signature below is unverified — confirm it exists.
    return evaluate.load("whisper")


@functools.lru_cache(maxsize=None)
def _load_toxicity_module():
    """Load the ``evaluate`` toxicity measurement once and reuse it across calls.

    The classifier model name is passed as the config name. ``toxicity`` is
    published as a *measurement*, so ``module_type`` must say so (matching the
    commented-out DaNLP alternative that this file previously carried).
    """
    return evaluate.load(
        "toxicity",
        "facebook/roberta-hate-speech-dynabench-r4-target",
        module_type="measurement",
    )


@functools.lru_cache(maxsize=None)
def _load_zero_shot_classifier():
    """Build the zero-shot classification pipeline once, on GPU when available."""
    # Integer device ids (0 = first GPU, -1 = CPU) are accepted by all
    # transformers versions, unlike torch.device objects.
    device = 0 if torch.cuda.is_available() else -1
    return pipeline(
        "zero-shot-classification",
        model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
        device=device,
    )


def classify_toxicity(audio_file, text_input, classify_anxiety):
    """Score the toxicity of an audio clip or a text and zero-shot classify it.

    Args:
        audio_file: Path of an uploaded audio file, or ``None`` when no audio
            was given. When present it takes precedence over *text_input*.
        text_input: Text to analyse when no audio file is given.
        classify_anxiety: Category chosen in the radio control. It is used as
            the candidate label for zero-shot classification. When it is empty
            or ``None`` (nothing selected), classification is skipped.

    Returns:
        A ``(toxicity_score, transcribed_text)`` tuple. The score is the first
        entry of the measurement's ``"toxicity"`` list. The text is the audio
        transcription, or *text_input* unchanged when there was no audio.
    """
    if audio_file is not None:
        # Transcribe the audio file using the Whisper module.
        transcription_results = _load_whisper_module().compute(uploaded=audio_file)
        transcribed_text = transcription_results["transcription"]
    else:
        transcribed_text = text_input

    toxicity_results = _load_toxicity_module().compute(predictions=[transcribed_text])
    toxicity_score = toxicity_results["toxicity"][0]
    print(toxicity_score)

    # The zero-shot pipeline rejects an empty label set, so only classify when
    # a category has actually been selected.
    if classify_anxiety:
        classification_output = _load_zero_shot_classifier()(
            transcribed_text, classify_anxiety, multi_label=False
        )
        print(classification_output)

    return toxicity_score, transcribed_text
# Gradio UI: a category selector, audio/text inputs, and one output component
# per value returned by classify_toxicity.
with gr.Blocks() as iface:
    with gr.Column():
        classify = gr.Radio(
            ["racial identity hate", "LGBTQ+ hate", "sexually explicit", "misophonia"],
            label="Classification category",
        )
    with gr.Column():
        # Upload is a default audio source; the ``source=`` keyword was removed
        # in Gradio 4, so it is not passed here.
        aud_input = gr.Audio(type="filepath", label="Upload Audio File")
        text = gr.Textbox(label="Enter Text", placeholder="Enter text here...")
        # A Button's caption is its ``value``; Button has no ``label`` parameter.
        submit_btn = gr.Button("Run")
    with gr.Column():
        # classify_toxicity returns (toxicity_score, transcribed_text), so the
        # event needs exactly two output components, in that order.
        score_out = gr.Number(label="Toxicity Score")
        out_text = gr.Textbox(label="Transcribed Text")
    # Event listeners must be registered inside the Blocks context.
    submit_btn.click(
        fn=classify_toxicity,
        inputs=[aud_input, text, classify],
        outputs=[score_out, out_text],
    )

if __name__ == "__main__":
    iface.launch()