import os

import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Load the fine-tuned mT5 model once at import time so every Gradio
# handler below shares the same tokenizer/model pair.
model_name = "Addaci/mT5-small-experiment-13-checkpoint-2790"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)


def _run_model(prompt, max_length, label):
    """Tokenize *prompt*, run beam-search generation, and decode the result.

    Shared pipeline for all three handlers: inputs are truncated to the
    model's 512-token context, generation uses 4 beams with early stopping,
    and special tokens are stripped from the decoded output.

    Args:
        prompt: Fully formatted input text (task prefix already applied).
        max_length: Maximum length of the generated sequence, in tokens.
        label: Human-readable task name used in the debug prints.

    Returns:
        The decoded model output as a plain string.
    """
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
    print(f"Tokenized Inputs for {label}:", inputs)  # Debugging
    outputs = model.generate(**inputs, max_length=max_length, num_beams=4, early_stopping=True)
    print(f"Generated Output (Tokens) for {label}:", outputs)  # Debugging
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Decoded Output for {label}:", decoded)  # Debugging
    return decoded


def correct_htr(raw_htr_text):
    """Correct raw HTR (handwritten text recognition) output with the mT5 model.

    Args:
        raw_htr_text: The uncorrected HTR transcription.

    Returns:
        The model's corrected transcription (empty string for empty input).
    """
    if not raw_htr_text:
        # Nothing to correct; skip a pointless (and slow) generation call.
        return ""
    return _run_model(raw_htr_text, max_length=128, label="HTR Correction")


def summarize_text(legal_text):
    """Summarize *legal_text* using the T5 "summarize:" task prefix.

    Args:
        legal_text: The document text to summarize.

    Returns:
        The generated summary (empty string for empty input).
    """
    if not legal_text:
        return ""
    return _run_model("summarize: " + legal_text, max_length=150, label="Summarization")


def answer_question(legal_text, question):
    """Answer *question* about *legal_text* using the T5 QA input format.

    The input is formatted as "question: ... context: ...", the standard
    T5-style extractive QA prompt.

    Args:
        legal_text: The context document.
        question: The question to answer about the context.

    Returns:
        The generated answer string.
    """
    formatted_input = f"question: {question} context: {legal_text}"
    return _run_model(formatted_input, max_length=150, label="Question Answering")