import warnings

import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast

# Tokenizer for the same "t5-base" checkpoint the model below starts from.
tokenizer = T5TokenizerFast.from_pretrained("t5-base")

# Start from the stock "t5-base" architecture and weights; matching entries
# from the local checkpoint are copied over them below.
quantized_model = T5ForConditionalGeneration.from_pretrained("t5-base")

# Load the saved state dictionary. map_location="cpu" lets a checkpoint that
# was saved from a GPU be opened on a CPU-only host.
# NOTE(security): torch.load unpickles the file, so only point it at
# checkpoints from a trusted source.
state_dict = torch.load("quantized_model.pt", map_location="cpu")

# Keep only the entries whose names exist in the model. The model's key set is
# built once instead of rebuilding state_dict() for every checkpoint key.
model_keys = set(quantized_model.state_dict())
filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_keys}
skipped_keys = sorted(set(state_dict) - model_keys)

# strict=False tolerates model parameters that the checkpoint does not supply;
# the returned record lists them so the mismatch can be reported.
load_result = quantized_model.load_state_dict(filtered_state_dict, strict=False)

# Filtering plus strict=False would otherwise hide a checkpoint that does not
# line up with this architecture, leaving stock weights in place unnoticed.
if skipped_keys or load_result.missing_keys:
    warnings.warn(
        "quantized_model.pt does not fully match the t5-base architecture: "
        f"{len(skipped_keys)} checkpoint entries were skipped "
        f"(first: {skipped_keys[:5]}) and "
        f"{len(load_result.missing_keys)} model parameters kept their "
        f"pretrained values (first: {list(load_result.missing_keys)[:5]})."
    )

def encode_text(text):
    """Tokenize ``text`` into PyTorch tensors for the model.

    The input is truncated or padded so that it is always 512 tokens long.

    Args:
        text: Raw text to tokenize.

    Returns:
        A tuple ``(input_ids, attention_mask)`` taken from the tokenizer
        output.
    """
    tokenizer_options = {
        "max_length": 512,
        "padding": "max_length",
        "truncation": True,
        "return_attention_mask": True,
        "return_tensors": 'pt',
    }
    encoded = tokenizer.encode_plus(text, **tokenizer_options)
    return encoded["input_ids"], encoded["attention_mask"]

def generate_summary(input_ids, attention_mask, model):
    """Generate summary token ids with two-beam search.

    Args:
        input_ids: Encoded input; its ``device`` attribute decides where the
            model is placed before generating.
        attention_mask: Mask passed through to ``model.generate``.
        model: Object exposing ``to(device)`` and ``generate(**kwargs)``.

    Returns:
        Whatever ``model.generate`` returns for the given inputs.
    """
    # Put the model on the same device as the inputs before generating.
    target_device = input_ids.device
    placed_model = model.to(target_device)

    generation_options = {
        "max_length": 150,
        "num_beams": 2,
        "repetition_penalty": 2.5,
        "length_penalty": 1.0,
        "early_stopping": True,
    }
    return placed_model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        **generation_options,
    )

def decode_summary(generated_ids):
    """Decode generated token id sequences into a single string.

    Args:
        generated_ids: Iterable of token id sequences, one per generated
            output.

    Returns:
        The decoded texts concatenated in order with no separator, with
        special tokens removed.
    """
    decoded_pieces = []
    for token_ids in generated_ids:
        piece = tokenizer.decode(
            token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        decoded_pieces.append(piece)
    return "".join(decoded_pieces)

def summarize(text):
    """Summarize ``text`` end to end: encode, generate, decode.

    Args:
        text: Text to summarize. ``None``, an empty string, or a string of
            only whitespace yields an empty summary without running the model.

    Returns:
        The generated summary as a string, or ``""`` for blank input.
    """
    # Blank input has nothing to summarize; skip the expensive generation
    # instead of returning whatever the model emits for an all-padding input.
    if not text or not text.strip():
        return ""

    input_ids, attention_mask = encode_text(text)
    generated_ids = generate_summary(input_ids, attention_mask, quantized_model)
    return decode_summary(generated_ids)

# Create Gradio interface: a 10-line input textbox feeds summarize(), and the
# returned string is shown in the "Summary" textbox.
input_text = gr.Textbox(lines=10, label="Input Text")
output_text = gr.Textbox(label="Summary")

demo = gr.Interface(
    fn=summarize,
    inputs=input_text,
    outputs=output_text,
    title="Poem Pulse",
    description="Enter a Poem and get its Jist.",
)

# launch() is called at module level, so the app starts as soon as this file
# is executed.
demo.launch()