# Spaces: Runtime error  (status banner captured from the hosting page; not part of the program)
import os

import gradio as gr
import numpy as np
import torch
from transformers import AutoTokenizer, Pipeline

from inference_utils import prepare_stance_texts
from models import StanceEncoderModel
# Maps the index of the winning logit to the stance label shown to the user.
CLASS_DICT = {0: 'FAVOR', 1: 'AGAINST', 2: 'NEITHER'}

# Options handed to `prepare_stance_texts` when building the prompt.
# NOTE(review): the meaning of `masked_lm_prompt` (value 4) is defined by
# `inference_utils`, not by this file -- confirm there before changing it.
params = {
    'lang': 'pl',
    'masked_lm_prompt': 4,
}
class StancePipeline(Pipeline):
    """Pipeline that predicts the stance of a text towards a target.

    A call takes a dict with ``'text'`` and ``'target'`` keys and yields a
    dict with the predicted ``'stance'`` label (a value of ``CLASS_DICT``)
    and its softmax ``'score'``.
    """

    def _sanitize_parameters(self, **pipeline_parameters):
        """Split keyword arguments into preprocess/forward/postprocess kwargs.

        Every keyword argument is routed to ``preprocess``; ``_forward`` and
        ``postprocess`` receive none.
        """
        # NOTE(review): `preprocess` accepts no extra keyword arguments, so
        # any kwarg that reaches this method will raise TypeError there.
        return pipeline_parameters, {}, {}

    def preprocess(self, input):
        """Turn one ``{'text': ..., 'target': ...}`` example into model inputs.

        Returns a dict with ``input_ids``, ``attention_mask`` and
        ``sequence_ids``; the latter is a ``(1, seq_len)`` integer mask.
        """
        # `input` shadows the builtin, but the parameter name is kept so the
        # method's signature stays unchanged for existing callers.
        prompt_text, prompt_target = prepare_stance_texts(
            [input['text']], [input['target']], params, self.tokenizer)
        encoded = self.tokenizer(
            prompt_text, prompt_target,
            return_tensors="pt", padding=True, truncation='only_first')
        # 1 for tokens of the second sequence (the target prompt), 0 for
        # everything else, including tokens whose sequence id is None.
        target_mask = [int(seq_id == 1) for seq_id in encoded.sequence_ids()]
        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
            # Explicit dtype keeps the mask 64-bit on every platform.
            'sequence_ids': torch.tensor(target_mask, dtype=torch.long).unsqueeze(0),
        }

    def _forward(self, model_inputs):
        """Run the model on the tensors produced by ``preprocess``."""
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Convert logits into the winning stance label and its probability.

        Expects a single example: ``.item()`` raises if the reduced tensors
        hold more than one element.
        """
        probas = model_outputs["logits"].softmax(-1)
        # One reduction gives both the best probability and its class index.
        score, label_idx = probas.max(-1)
        return {'stance': CLASS_DICT[label_idx.item()], 'score': score.item()}
# Repository id shared by the model weights and the tokenizer.
_MODEL_ID = 'clarin-knext/stance-pl-1'
# Access token taken from the TOKEN environment variable. A missing variable
# raises KeyError right here, at import time, exactly as before.
_HF_TOKEN = os.environ['TOKEN']

# NOTE(review): recent transformers releases deprecate `use_auth_token` in
# favour of `token` -- confirm against the pinned version before switching.
model = StanceEncoderModel.from_pretrained(_MODEL_ID, use_auth_token=_HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, use_auth_token=_HF_TOKEN)

# Single shared pipeline instance used by `predict`.
pipeline = StancePipeline(model=model, tokenizer=tokenizer, batch_size=1)
def predict(text: str, target: str) -> str:
    """Classify the stance of ``text`` towards ``target``.

    Runs the module-level ``pipeline`` on the pair and returns the result
    formatted as ``"<STANCE> (<score>)"``, with the score rounded to three
    decimal places.
    """
    predictions = pipeline({'text': text, 'target': target})
    return f'{predictions["stance"]} ({predictions["score"]:.3f})'
# Web UI: two text inputs (the text and the stance target) feeding `predict`,
# whose formatted string is shown in a single label widget.
gradio_app = gr.Interface(
    predict,
    inputs=[
        gr.TextArea(label="Text", placeholder="Enter text here..."),
        gr.Textbox(label="Target", placeholder="Enter stance target here..."),
    ],
    outputs=[gr.Label(label="Stance class")],
    title="Polish stance detection",
)

# Start the server only when the file is executed as a script.
if __name__ == "__main__":
    gradio_app.launch()