Spaces:
Runtime error
Runtime error
File size: 3,458 Bytes
597f19c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import inspect
import sys
import time
import traceback

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Error text from a failed model/tokenizer load; stays "" when loading succeeds.
error_message = ""
# Global model and tokenizer, filled in by the loading code below (None if it fails).
model = None
tokenizer = None
# Choose the device up front so it is defined even if loading fails part-way.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the model and tokenizer from Hugging Face
model_name = "ambrosfitz/history-qa-flan-t5-large"
try:
    # No `global` statement here: this code already runs at module scope, and
    # declaring names global after they were assigned above is a SyntaxError
    # ("name 'model' is assigned to before global declaration").
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    model.to(device)
except Exception as e:
    # Best-effort load: record and print the failure so the UI can still start.
    error_message = f"Error loading model or tokenizer: {str(e)}\n{traceback.format_exc()}"
    print(error_message)
def _format_qa(generated_text):
    """Extract the first question/answer pair from decoded model output.

    The segment between the first and second "Question: " markers is split on
    "Answer: ". The first piece becomes the question. The second piece, if
    present, becomes the answer.

    Args:
        generated_text: Decoded text produced by the model.

    Returns:
        The question and answer as "Question: ..." and "Answer: ..." separated
        by a blank line. If the text has no "Question: " marker, a fallback
        message asking the user to try again.
    """
    parts = generated_text.split("Question: ")
    if len(parts) < 2:
        return "Unable to generate a proper question and answer. Please try again with a different input."
    qa_parts = parts[1].split("Answer: ")
    question = qa_parts[0].strip()
    answer = qa_parts[1].strip() if len(qa_parts) > 1 else "No answer provided."
    return f"Question: {question}\n\nAnswer: {answer}"


def generate_qa(text, max_length=512):
    """Generate a history question and answer from a passage of text.

    Args:
        text: The historical passage entered by the user.
        max_length: Token limit used both to truncate the prompt and as the
            ``max_length`` passed to ``model.generate``.

    Returns:
        One of the following strings:
        - the formatted question/answer pair;
        - a fallback message when the output cannot be parsed;
        - the recorded model-loading error when the model or tokenizer is
          unavailable;
        - an error message with traceback if tokenization or generation
          raises.
    """
    if model is None or tokenizer is None:
        # Model loading failed at startup. Report that failure instead of a
        # confusing "'NoneType' object is not callable" traceback.
        return error_message or "The model is not loaded. Please check the server logs."
    try:
        input_text = f"Generate a history question and answer based on this text: {text}"
        inputs = tokenizer(
            input_text, return_tensors="pt", max_length=max_length, truncation=True
        ).to(device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs.input_ids,
                # Pass the mask explicitly rather than letting generate() infer it.
                attention_mask=inputs.attention_mask,
                max_length=max_length,
                num_return_sequences=1,
                do_sample=True,
                temperature=0.7,
            )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return _format_qa(generated_text)
    except Exception as e:
        return f"An error occurred: {str(e)}\n{traceback.format_exc()}"
def slow_qa(message, history):
    """Stream the generated Q&A into the chat one character at a time.

    ``history`` is the chat history argument passed by ``gr.ChatInterface``.
    It is accepted for interface compatibility and is not used.

    Yields:
        Successively longer prefixes of the full response, pausing 10 ms
        before each one to produce a typewriter effect. If anything raises,
        yields a single message containing the error and its traceback.
    """
    try:
        response = generate_qa(message)
        for end in range(1, len(response) + 1):
            time.sleep(0.01)
            yield response[:end]
    except Exception as exc:
        yield f"An error occurred: {str(exc)}\n{traceback.format_exc()}"
# Create and launch the Gradio interface
try:
    # Gradio 4.x accepts these button-label arguments, but Gradio 5 removed them
    # from gr.ChatInterface. Passing them unconditionally then raises TypeError
    # and the app never launches, so forward only the ones this version declares.
    legacy_button_labels = {
        "retry_btn": "Regenerate",
        "undo_btn": "Remove last",
        "clear_btn": "Clear",
    }
    chat_params = inspect.signature(gr.ChatInterface.__init__).parameters
    button_kwargs = {
        name: label
        for name, label in legacy_button_labels.items()
        if name in chat_params
    }
    iface = gr.ChatInterface(
        slow_qa,
        chatbot=gr.Chatbot(height=500),
        textbox=gr.Textbox(placeholder="Enter historical text here...", container=False, scale=7),
        title="History Q&A Generator (FLAN-T5)",
        description="Enter a piece of historical text, and the model will generate a related question and answer.",
        theme="soft",
        examples=[
            "The American Revolution was a colonial revolt that took place between 1765 and 1783.",
            "World War II was a global conflict that lasted from 1939 to 1945, involving many of the world's nations.",
            "The Renaissance was a period of cultural, artistic, political, and economic revival following the Middle Ages."
        ],
        cache_examples=False,
        **button_kwargs,
    )
    # Launch even if model loading failed so the app stays up; the load error
    # itself was already printed when it occurred.
    if error_message:
        print("Launching interface with error message.")
    else:
        print("Launching interface normally.")
    iface.launch(debug=True)
except Exception as e:
    # Setup/launch failures are logged rather than re-raised.
    print(f"An error occurred while creating or launching the interface: {str(e)}\n{traceback.format_exc()}")