import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Set page configuration
st.set_page_config(
    page_title="GemmaTextAppeal",
    page_icon="✨",
    layout="wide",
)

# App title and description
st.title("✨ GemmaTextAppeal")
st.markdown("""
### Interactive Demo of Google's Gemma 2-2B-IT Model

This app demonstrates the text generation capabilities of Google's Gemma 2-2B-IT model.
Enter a prompt below and see the model generate text in real-time!
""")

# Function to load the model (cached so it only loads once per session)
@st.cache_resource(show_spinner=False)
def load_model():
    try:
        # Get API token from the environment
        huggingface_token = os.getenv("HF_TOKEN")
        if not huggingface_token:
            return None, None, "No Hugging Face API token found. Please add your token as a secret named 'HF_TOKEN'."

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            "google/gemma-2-2b-it",
            token=huggingface_token
        )

        # Load model with appropriate configuration for the available hardware
        model_kwargs = {
            "token": huggingface_token,
            "device_map": "auto" if torch.cuda.is_available() else None,
            "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
        }

        model = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2-2b-it",
            **model_kwargs
        )

        return tokenizer, model, None
    except Exception as e:
        return None, None, str(e)

# Try to load the model at startup
with st.spinner("Initializing the Gemma model... this may take a minute."):
    tokenizer, model, load_error = load_model()

if load_error:
    st.error(f"Error loading model: {load_error}")
else:
    if tokenizer and model:
        st.success("✅ Gemma model loaded successfully! Ready to generate text.")
    else:
        st.warning("⚠️ Model not loaded. Please check your Hugging Face token.")

# Check for Hugging Face token
huggingface_token = os.getenv("HF_TOKEN")
if not huggingface_token:
    st.warning("""
    ⚠️ **No Hugging Face API token detected**

    The Gemma models require accepting a license and authentication to use.
    To make this app work:

    1. Create a Hugging Face account
    2. Accept the model license at: https://huggingface.co/google/gemma-2-2b-it
    3. Create a HF token at: https://huggingface.co/settings/tokens
    4. Add your token as a secret named 'HF_TOKEN' in your Space settings
    """)

# Sidebar with information
with st.sidebar:
    st.header("About Gemma")
    st.markdown("""
    [Gemma 2-2B-IT](https://huggingface.co/google/gemma-2-2b-it) is a lightweight,
    2B-parameter, instruction-tuned model from Google's Gemma family.

    Key features:
    - Efficient text generation
    - Strong instruction following
    - 2 billion parameters - fast enough to run on consumer hardware
    - Trained on a mixture of text and code

    This demo runs directly on Hugging Face Spaces!
    """)

    st.header("Usage Tips")
    st.markdown("""
    - Be specific in your prompts
    - You can ask for creative content, summaries, or answers to questions
    - The model performs best when given clear instructions
    - Try different temperatures to vary creativity vs. coherence
    """)

    st.header("Sample Prompts")
    sample_prompts = [
        "Write a short story about a robot discovering emotions",
        "Explain quantum computing to a 10-year-old",
        "Create a recipe for vegan chocolate chip cookies",
        "Write a haiku about artificial intelligence",
        "Describe the benefits and risks of generative AI"
    ]
    for i, prompt in enumerate(sample_prompts):
        if st.button(f"Example {i+1}", key=f"sample_{i}"):
            st.session_state.user_prompt = prompt

# Initialize session state variables
if 'user_prompt' not in st.session_state:
    st.session_state.user_prompt = ""
if 'generation_complete' not in st.session_state:
    st.session_state.generation_complete = False
if 'generated_text' not in st.session_state:
    st.session_state.generated_text = ""
if 'error_message' not in st.session_state:
    st.session_state.error_message = None

# Model parameters
col1, col2 = st.columns(2)
with col1:
    max_length = st.slider("Maximum Length", min_value=50, max_value=1000, value=300, step=50,
                           help="Maximum number of tokens to generate")
with col2:
    temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.7, step=0.1,
                            help="Higher values make output more random, lower values more deterministic")

# User input
user_input = st.text_area("Enter your prompt:",
                          value=st.session_state.user_prompt,
                          height=100,
                          placeholder="e.g., Write a short story about a robot discovering emotions")

def generate_text_streaming(prompt, max_new_tokens=300, temperature=0.7):
    if not tokenizer or not model:
        st.session_state.error_message = "Model not properly loaded. Please check your Hugging Face token."
        return None

    try:
        # Format the prompt according to Gemma's expected chat format (turn markers)
        formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

        # Create the output area
        output_container = st.empty()
        response_area = st.container()
        with response_area:
            st.markdown("**Generated Response:**")
            response_text = st.empty()

        # Tokenize the input
        encoding = tokenizer(formatted_prompt, return_tensors="pt")

        # Move to the appropriate device
        if torch.cuda.is_available():
            encoding = {k: v.to("cuda") for k, v in encoding.items()}

        # Store the length of the input to track new tokens
        input_length = encoding["input_ids"].shape[1]

        # Initialize generated text container
        generated_text = ""

        # Generate tokens with streaming: sample one token per iteration so the
        # partial response can be displayed as it grows
        generated_ids = []

        for _ in range(max_new_tokens):
            with torch.no_grad():
                if len(generated_ids) == 0:
                    # First token generation
                    outputs = model.generate(
                        **encoding,
                        max_new_tokens=1,
                        do_sample=True,
                        temperature=temperature,
                        pad_token_id=tokenizer.eos_token_id,
                        return_dict_in_generate=True,
                        output_scores=False
                    )
                    next_token_id = outputs.sequences[0, input_length:input_length + 1]
                else:
                    # Subsequent tokens: re-feed the prompt plus everything generated so far
                    current_input_ids = torch.cat(
                        [encoding["input_ids"],
                         torch.tensor([generated_ids], device=encoding["input_ids"].device)],
                        dim=1
                    )
                    outputs = model.generate(
                        input_ids=current_input_ids,
                        max_new_tokens=1,
                        do_sample=True,
                        temperature=temperature,
                        pad_token_id=tokenizer.eos_token_id,
                        return_dict_in_generate=True,
                        output_scores=False
                    )
                    next_token_id = outputs.sequences[0, -1].unsqueeze(0)

            # Convert to a Python list and append
            next_token_id_list = next_token_id.tolist()
            generated_ids.extend(next_token_id_list)

            # Check for EOS token
            if tokenizer.eos_token_id in next_token_id_list:
                break

            # Decode the tokens generated so far and update the displayed text
            current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
            generated_text = current_text
            response_text.markdown(generated_text)

        return generated_text

    except Exception as e:
st.session_state.error_message = f"Error during generation: {str(e)}" st.error(f"Error during generation: {str(e)}") return None # Show any existing error if st.session_state.error_message: st.error(f"Error: {st.session_state.error_message}") # Add troubleshooting information with st.expander("Troubleshooting Information"): st.markdown(""" ### Common Issues: 1. **Missing Hugging Face Token**: The Gemma model requires authentication. Add your token as a secret named 'HF_TOKEN' in the Space settings. 2. **License Acceptance**: You need to accept the model license on the [Gemma model page](https://huggingface.co/google/gemma-2-2b-it). 3. **Internet Connection**: The model needs to be downloaded the first time the app runs. Ensure your Space has internet access. 4. **Resource Constraints**: The Gemma model requires significant resources. Consider upgrading your Space's hardware if you're encountering memory issues. ### How to Fix: 1. Create a [Hugging Face account](https://huggingface.co/join) 2. Visit the [Gemma model page](https://huggingface.co/google/gemma-2-2b-it) and accept the license 3. Create a token at https://huggingface.co/settings/tokens 4. Add your token to the Space: Settings → Secrets → New Secret (HF_TOKEN) """) # Add a debug section with st.expander("Debug Information"): st.write(f"Model loaded: {model is not None}") st.write(f"Tokenizer loaded: {tokenizer is not None}") st.write(f"Device: {model.device if model else 'N/A'}") st.write(f"Hugging Face token set: {huggingface_token is not None}") if torch.cuda.is_available(): st.write(f"CUDA available: True (Device count: {torch.cuda.device_count()})") else: st.write("CUDA available: False") # Generate button if st.button("Generate Text"): # Reset any previous errors st.session_state.error_message = None if not huggingface_token: st.error("Hugging Face token is required! Please add your token as described above.") elif user_input: st.session_state.user_prompt = user_input result = generate_text_streaming(user_input, max_length, temperature) if result is not None: # Only set if no error occurred st.session_state.generated_text = result st.session_state.generation_complete = True else: st.error("Please enter a prompt first!") # Analysis section (only show after generation is complete) if st.session_state.generation_complete and not st.session_state.error_message and st.session_state.generated_text: # Analysis section with st.expander("Text Analysis"): col1, col2 = st.columns(2) with col1: st.metric("Character Count", len(st.session_state.generated_text)) st.metric("Word Count", len(st.session_state.generated_text.split())) with col2: st.metric("Sentence Count", st.session_state.generated_text.count('.') + st.session_state.generated_text.count('!') + st.session_state.generated_text.count('?')) st.metric("Paragraph Count", st.session_state.generated_text.count('\n\n') + 1) # Footer st.markdown("---") st.markdown("""

Created with ❤️ | Powered by Gemma 2-2B-IT and Hugging Face

Code available on Hugging Face Spaces

""", unsafe_allow_html=True)