"""Gradio app ("Poem Pulse") that summarizes input text with a T5 model.

At import time the script loads the ``t5-base`` tokenizer and model,
overlays the weights found in ``quantized_model.pt`` on top of the model,
and launches a Gradio interface whose handler is :func:`summarize`.
"""

import logging

import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast

logger = logging.getLogger(__name__)

tokenizer = T5TokenizerFast.from_pretrained("t5-base")

# Define the quantized model architecture
quantized_model = T5ForConditionalGeneration.from_pretrained("t5-base")

# Load the state dictionary.
# map_location="cpu" lets a checkpoint that was saved from a CUDA device load
# on a CPU-only host; the model created above lives on the CPU anyway.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source.
state_dict = torch.load("quantized_model.pt", map_location="cpu")

# Filter out keys that are not present in the quantized model.
# The model's key set is built once instead of rebuilding the full
# state_dict() for every checkpoint entry.
model_keys = set(quantized_model.state_dict())
filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_keys}
skipped_keys = sorted(k for k in state_dict if k not in model_keys)

# Load the filtered state dictionary into the quantized model
load_result = quantized_model.load_state_dict(filtered_state_dict, strict=False)

# strict=False plus the filter above would otherwise hide a checkpoint that
# barely overlaps with the model, so report what was ignored on either side.
if skipped_keys:
    logger.warning(
        "Ignored %d checkpoint key(s) absent from the model, e.g. %s",
        len(skipped_keys),
        skipped_keys[:5],
    )
if load_result.missing_keys:
    logger.warning(
        "%d model key(s) were not found in the checkpoint, e.g. %s",
        len(load_result.missing_keys),
        list(load_result.missing_keys)[:5],
    )


def encode_text(text):
    """Tokenize *text* for the model.

    The text is truncated and padded to exactly 512 tokens.

    Args:
        text: The raw input string.

    Returns:
        A ``(input_ids, attention_mask)`` tuple of PyTorch tensors.
    """
    # Calling the tokenizer directly is the current equivalent of the legacy
    # ``encode_plus`` method and accepts the same arguments.
    encoding = tokenizer(
        text,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    return encoding["input_ids"], encoding["attention_mask"]


def generate_summary(input_ids, attention_mask, model):
    """Generate summary token ids with beam search.

    Args:
        input_ids: Token id tensor as produced by :func:`encode_text`.
        attention_mask: Attention mask tensor matching ``input_ids``.
        model: The model whose ``generate`` method is used; it is moved to the
            device of ``input_ids`` first.

    Returns:
        The tensor of generated token ids (at most 150 tokens per sequence).
    """
    model = model.to(input_ids.device)
    generated_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=150,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
    )
    return generated_ids


def decode_summary(generated_ids):
    """Decode generated token ids back into a single string.

    Args:
        generated_ids: Iterable of generated token id sequences.

    Returns:
        The decoded sequences, with special tokens removed, concatenated
        without a separator.
    """
    summary = [
        tokenizer.decode(
            gen_id,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for gen_id in generated_ids
    ]
    return "".join(summary)


def summarize(text):
    """Summarize *text*; this is the handler wired into the Gradio interface.

    Args:
        text: The text entered by the user.

    Returns:
        The generated summary, or an empty string when *text* is empty or
        contains only whitespace.
    """
    # Nothing to summarize: skip the (expensive) generation pass entirely.
    if not text or not text.strip():
        return ""

    input_ids, attention_mask = encode_text(text)
    generated_ids = generate_summary(input_ids, attention_mask, quantized_model)
    summary = decode_summary(generated_ids)
    return summary


# Create Gradio interface
input_text = gr.Textbox(lines=10, label="Input Text")
output_text = gr.Textbox(label="Summary")

gr.Interface(
    fn=summarize,
    inputs=input_text,
    outputs=output_text,
    title="Poem Pulse",
    description="Enter a Poem and get its Jist.",
).launch()