import os
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import traceback

# Cache models under the app's working directory so repeated runs
# (e.g. on Hugging Face Spaces) can reuse the downloaded weights
CACHE_DIR = os.path.join(os.getcwd(), "model_cache")


def ensure_cache_dir():
    """
    Ensure the cache directory exists.

    Returns:
        str: Path to the cache directory
    """
    os.makedirs(CACHE_DIR, exist_ok=True)
    return CACHE_DIR


@st.cache_resource  # load each model once per process instead of on every rerun
def load_model_and_tokenizer(model_name):
    """
    Load model and tokenizer with persistent caching.

    Args:
        model_name (str): Name of the model to load

    Returns:
        tuple: (model, tokenizer)
    """
    try:
        # Ensure cache directory exists
        cache_dir = ensure_cache_dir()

        # Construct a per-model cache path (slashes are not valid in directory names)
        model_cache_path = os.path.join(cache_dir, model_name.replace('/', '_'))

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=model_cache_path
        )

        # Load model
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            cache_dir=model_cache_path
        )

        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading {model_name}: {str(e)}")
        st.error(traceback.format_exc())
        return None, None


def generate_summary(model, tokenizer, text, max_length=150, prefix=""):
    """
    Generate a summary using a specific model and tokenizer.

    Args:
        model: Hugging Face model
        tokenizer: Hugging Face tokenizer
        text (str): Input text to summarize
        max_length (int): Maximum length of the summary
        prefix (str): Task prefix prepended to the input. T5 was trained
            with "summarize: "; BART and Pegasus expect the raw text.

    Returns:
        str: Generated summary
    """
    try:
        # Prepare input
        inputs = tokenizer(
            f"{prefix}{text}",
            max_length=512,
            return_tensors="pt",
            truncation=True
        )

        # Generate summary
        summary_ids = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            num_beams=4,
            max_length=max_length,
            early_stopping=True
        )

        # Decode summary
        summary = tokenizer.decode(
            summary_ids[0],
            skip_special_tokens=True
        )

        return summary
    except Exception as e:
        error_msg = f"Error in summarization: {str(e)}"
        st.error(error_msg)
        return error_msg


def main():
    st.title("Text Summarization with Pre-trained Models")

    # Display cache directory info (optional)
    st.info(f"Models will be cached in: {CACHE_DIR}")

    # Model display names mapped to their Hugging Face checkpoints
    models_to_load = {
        'BART': 'facebook/bart-large-cnn',
        'T5': 't5-large',
        'Pegasus': 'google/pegasus-cnn_dailymail'
    }

    # Text input
    text_input = st.text_area("Enter text to summarize:")

    # Generate button
    if st.button("Generate Summary"):
        if not text_input:
            st.error("Please enter text to summarize.")
            return

        # Create one column per model for progressive display
        columns = st.columns(len(models_to_load))

        # Function to process each model
        def process_model(col, model_name, model_path):
            with col:
                with st.spinner(f'Generating {model_name} Summary...'):
                    progress = st.progress(0)
                    progress.progress(50)

                    # Load model and tokenizer
                    model, tokenizer = load_model_and_tokenizer(model_path)

                    if model is not None and tokenizer is not None:
                        # Only T5 expects an explicit task prefix
                        prefix = "summarize: " if model_name == 'T5' else ""

                        # Generate summary
                        summary = generate_summary(
                            model, tokenizer, text_input, prefix=prefix
                        )
                        progress.progress(100)

                        st.subheader(f"{model_name} Summary")
                        st.write(summary)
                    else:
                        st.error(f"Failed to load {model_name} model")

        # Process each model in its own column
        for col, (model_name, model_path) in zip(columns, models_to_load.items()):
            process_model(col, model_name, model_path)


if __name__ == "__main__":
    main()
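
# Running locally (assumes this file is saved as app.py; the filename is arbitrary):
#
#   pip install streamlit transformers torch sentencepiece
#   streamlit run app.py
#
# sentencepiece may be required for the T5 and Pegasus tokenizers, and the
# three checkpoints total several GB, so the first run will spend a while
# downloading before any summaries appear.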