Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
from matplotlib import cm | |
import torch | |
from transformers import AutoTokenizer, AutoModel, AutoConfig | |
from model import Classifier | |
import torch.nn as nn | |
import torch.nn.functional as F | |
# Load model directly
MODEL_NAME = "cahya/roberta-base-indonesian-522M"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): these genre labels are not referenced anywhere in this file and
# look like a leftover from another project — confirm before removing.
class_names = ['Action', 'Adventure', 'Comedy', 'Drama', 'Fantasy', 'Romance', 'Sci-Fi']

# Pretrained backbone: encoder, its config and matching tokenizer.
config = AutoConfig.from_pretrained(MODEL_NAME)
transformer = AutoModel.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Fine-tuned checkpoint: 'w_t' is the encoder state dict, 'w_c' the classifier
# heads. Loaded onto CPU first, then moved to `device` below.
# torch.load unpickles the file, so only point it at a trusted checkpoint.
cp = torch.load(r"weight.pt", map_location="cpu")
transformer.load_state_dict(cp['w_t'])
# Head sizes, in the order predict() unpacks them: hate speech (1),
# abusive (1), target (1), strength (3 classes), type (5 labels).
classifier = Classifier(input_size=config.hidden_size, output_sizes=[1, 1, 1, 3, 5])
classifier.load_state_dict(cp['w_c'])
transformer.to(device)
classifier.to(device)
# Inference only: a newly constructed nn.Module starts in training mode, so
# without eval() any dropout inside the classifier would make predictions
# vary between identical requests.
transformer.eval()
classifier.eval()

# Human-readable labels for the target / strength / type heads.
target_names = ["Individual", 'Group']
strength_names = ["Weak", 'Moderate', 'Strong']
type_names = ['Religion', 'Race', 'Physical', 'Gender', 'Other']
# Activations for the raw head outputs. The softmax axis is explicit:
# nn.Softmax() without `dim` is deprecated and chooses the axis implicitly.
act_sig = nn.Sigmoid()
act_soft = nn.Softmax(dim=-1)
def predict(sentence):
    """Classify one sentence for hate speech and abusive language.

    Runs the fine-tuned encoder and the multi-head classifier on ``sentence``
    and converts the head activations into labels.

    Args:
        sentence: Raw input text (from the Gradio textbox).

    Returns:
        A 5-tuple ``(is_hate_speech, is_abusive, target_label,
        strength_label, {"hs_type": types})``. The first two items are bools.
        If the sentence is not hate speech, target/strength are ``"Non-HS"``
        and ``types`` is the string ``"Non-HS"``. Otherwise target is one of
        ``target_names``, strength one of ``strength_names``, and ``types``
        is a non-empty list of ``type_names`` entries (falls back to
        ``["Other"]`` when no type head fires).
    """
    # Tokenize to a single padded/truncated sequence of length 256.
    inputs = tokenizer(
        sentence,
        add_special_tokens=True,
        max_length=256,
        padding="max_length",
        truncation=True,
        return_tensors='pt',
    )
    input_ids = inputs['input_ids'].to(device)
    att_masks = inputs['attention_mask'].to(device)

    # Keep the entire forward pass (encoder AND classifier heads) out of autograd.
    with torch.no_grad():
        encoded = transformer(input_ids, attention_mask=att_masks)
        heads = classifier(encoded.pooler_output)
        hs_out, abusive_out, target_out, strength_out, type_out = (
            heads[0], heads[1], heads[2], heads[3], heads[4]
        )

    # Binary heads use sigmoid; strength is a single-choice head, so softmax
    # over its last (class) axis.
    hs_prob = torch.sigmoid(hs_out).squeeze()
    abusive_prob = torch.sigmoid(abusive_out).squeeze()
    target_prob = torch.sigmoid(target_out).squeeze(0)
    strength_prob = torch.softmax(strength_out, dim=-1)
    type_prob = torch.sigmoid(type_out).squeeze(0)

    is_hate_speech = bool(hs_prob >= 0.5)
    is_abusive = bool(abusive_prob >= 0.5)

    if not is_hate_speech:
        return is_hate_speech, is_abusive, "Non-HS", "Non-HS", {"hs_type": "Non-HS"}

    hate_speech_target_label = target_names[int(target_prob >= 0.5)]
    hate_speech_strength_label = strength_names[strength_prob.argmax().item()]
    # Type is multi-label: keep every category whose probability passes 0.5.
    hate_speech_type_label = [
        type_names[idx] for idx, prob in enumerate(type_prob) if prob >= 0.5
    ]
    if not hate_speech_type_label:
        hate_speech_type_label.append("Other")

    return (
        is_hate_speech,
        is_abusive,
        hate_speech_target_label,
        hate_speech_strength_label,
        {"hs_type": hate_speech_type_label},
    )
# Assemble the Gradio UI: one text input and five outputs, listed in the same
# order as the tuple returned by predict().
sentence_box = gr.Textbox(label="Enter a sentence")
result_panels = [
    gr.Label(label="Is Hate Speech"),
    gr.Label(label="Is Abusive"),
    gr.Label(label="Hate Speech Target"),
    gr.Label(label="Hate Speech Strength"),
    gr.JSON(label="Hate Speech Type"),
]
iface = gr.Interface(
    fn=predict,
    inputs=sentence_box,
    outputs=result_panels,
    title="Hate Speech Detection",
)
# Start serving the app.
iface.launch()