# Spaces: Runtime error
import os | |
import gradio as gr | |
from transformers import T5ForConditionalGeneration, AutoTokenizer | |
# from transformers import pipeline | |
# Access token for the model repository, read from the CLARIN_KNEXT
# environment variable (None when the variable is not set).
auth_token = os.getenv("CLARIN_KNEXT")

# Checkpoint to load. Alternative name left by the original author:
# "clarin-knext/plt5-large-poquad-ext-qa-autotoken".
model_name = "clarin-knext/plt5-large-poquad"

# NOTE(review): newer transformers releases reportedly deprecate
# `use_auth_token` in favour of `token` -- confirm against the pinned
# version before switching.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_auth_token=auth_token,
)
model = T5ForConditionalGeneration.from_pretrained(
    model_name,
    use_auth_token=auth_token,
)

# Keyword arguments forwarded to `model.generate` on every request.
default_generate_kwargs = dict(
    max_length=192,
    num_beams=2,
    length_penalty=0,
    early_stopping=True,
)

# Sample (question, context) rows offered in the demo UI.
examples = [
    [
        "Jakie miasto jest stolicą Polski?",
        "Polska ma wiele wspaniałych miast, Wrocław, Poznań czy Gdańsk. Jednak stolicą jest Warszawa.",
    ],
]
def generate(question, context):
    """Answer ``question`` from ``context`` with the module-level model.

    The two strings are merged into one prompt of the form
    ``question: ... context: ... </s>``, tokenized (truncated to at most
    512 tokens, no padding), and passed to ``model.generate`` together with
    ``default_generate_kwargs``. The first returned sequence is decoded
    with special tokens stripped.

    Args:
        question: Question text.
        context: Passage text placed after the question in the prompt.

    Returns:
        The decoded text of the first generated sequence.
    """
    # NOTE(review): the prompt ends with a literal "</s>" while the
    # tokenizer is also asked to add special tokens -- confirm that the
    # checkpoint expects both before changing either.
    prompt = f"question: {question} context: {context} </s>"

    encoded = tokenizer(
        prompt,
        truncation=True,
        max_length=512,
        padding=False,
        add_special_tokens=True,
        return_tensors="pt",
    )

    generated = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        **default_generate_kwargs,
    )

    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
# Input widgets, listed in the same order as the (question, context)
# parameters of `generate`.
question_box = gr.Textbox(lines=1, label="Question")
context_box = gr.Textbox(lines=5, label="Context")

# Output widget that receives the string returned by `generate`.
answer_box = gr.Textbox(label="Answer")

demo = gr.Interface(
    fn=generate,
    inputs=[question_box, context_box],
    outputs=answer_box,
    examples=examples,
)

# Start serving the interface as soon as the script is executed.
demo.launch()