import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
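# HF_TOKEN can come from a local .env file during development or from a Space secret in production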
# Set page configuration
st.set_page_config(
page_title="GemmaTextAppeal",
page_icon="✨",
layout="wide",
)
# App title and description
st.title("✨ GemmaTextAppeal")
st.markdown("""
### Interactive Demo of Google's Gemma 2-2B-IT Model
This app demonstrates the text generation capabilities of Google's Gemma 2-2B-IT model.
Enter a prompt below and see the model generate text in real-time!
""")
# Load the tokenizer and model once; st.cache_resource keeps them cached across Streamlit reruns
@st.cache_resource(show_spinner=False)
def load_model():
try:
# Get API Token
huggingface_token = os.getenv("HF_TOKEN")
if not huggingface_token:
return None, None, "No Hugging Face API token found. Please add your token as a secret named 'HF_TOKEN'."
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
"google/gemma-2-2b-it",
token=huggingface_token
)
        # Load the model: fp16 with automatic device placement on GPU, fp32 on CPU
model_kwargs = {
"token": huggingface_token,
"device_map": "auto" if torch.cuda.is_available() else None,
"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
}
model = AutoModelForCausalLM.from_pretrained(
"google/gemma-2-2b-it",
**model_kwargs
)
return tokenizer, model, None
except Exception as e:
return None, None, str(e)
# Try to load the model at startup
with st.spinner("Initializing the Gemma model... this may take a minute."):
tokenizer, model, load_error = load_model()
if load_error:
st.error(f"Error loading model: {load_error}")
else:
if tokenizer and model:
st.success("✅ Gemma model loaded successfully! Ready to generate text.")
else:
st.warning("⚠️ Model not loaded. Please check your Hugging Face token.")
# Check for Hugging Face Token
huggingface_token = os.getenv("HF_TOKEN")
if not huggingface_token:
st.warning("""
⚠️ **No Hugging Face API token detected**
The Gemma models are gated: you must accept the license and authenticate with a Hugging Face token.
To make this app work:
1. Create a Hugging Face account
2. Accept the model license at: https://huggingface.co/google/gemma-2-2b-it
3. Create an access token at: https://huggingface.co/settings/tokens
4. Add your token as a secret named 'HF_TOKEN' in your Space settings
""")
# Sidebar with information
with st.sidebar:
st.header("About Gemma")
st.markdown("""
[Gemma 2-2B-IT](https://huggingface.co/google/gemma-2-2b-it) is a lightweight 2B parameter instruction-tuned model from Google's Gemma family.
Key features:
- Efficient text generation
- Strong instruction following
- Only 2 billion parameters, small enough to run on consumer hardware
- Trained on a mixture of text and code
This demo runs directly on Hugging Face Spaces!
""")
st.header("Usage Tips")
st.markdown("""
- Be specific in your prompts
- You can ask for creative content, summaries, or answers to questions
- The model performs best when given clear instructions
- Try different temperatures to vary creativity vs. coherence
""")
st.header("Sample Prompts")
sample_prompts = [
"Write a short story about a robot discovering emotions",
"Explain quantum computing to a 10-year old",
"Create a recipe for vegan chocolate chip cookies",
"Write a haiku about artificial intelligence",
"Describe the benefits and risks of generative AI"
]
for i, prompt in enumerate(sample_prompts):
if st.button(f"Example {i+1}", key=f"sample_{i}"):
st.session_state.user_prompt = prompt
# Initialize session state (Streamlit reruns this script on every interaction, so values must persist here)
if 'user_prompt' not in st.session_state:
st.session_state.user_prompt = ""
if 'generation_complete' not in st.session_state:
st.session_state.generation_complete = False
if 'generated_text' not in st.session_state:
st.session_state.generated_text = ""
if 'error_message' not in st.session_state:
st.session_state.error_message = None
# Model parameters
col1, col2 = st.columns(2)
with col1:
    max_length = st.slider("Maximum Length", min_value=50, max_value=1000, value=300, step=50,
                           help="Maximum number of new tokens to generate")
with col2:
temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.7, step=0.1,
help="Higher values make output more random, lower values more deterministic")
# User input
user_input = st.text_area("Enter your prompt:",
value=st.session_state.user_prompt,
height=100,
placeholder="e.g., Write a short story about a robot discovering emotions")
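# Generate text token by token so partial output can be rendered while the model is still running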
def generate_text_streaming(prompt, max_new_tokens=300, temperature=0.7):
if not tokenizer or not model:
st.session_state.error_message = "Model not properly loaded. Please check your Hugging Face token."
return None
try:
        # Format the prompt in Gemma's chat format; the tokenizer adds <bos> itself,
        # so it is omitted here to avoid a duplicated BOS token
        formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
# Create the output area
output_container = st.empty()
response_area = st.container()
with response_area:
st.markdown("**Generated Response:**")
response_text = st.empty()
# Tokenize the input
encoding = tokenizer(formatted_prompt, return_tensors="pt")
# Move to the appropriate device
if torch.cuda.is_available():
encoding = {k: v.to("cuda") for k, v in encoding.items()}
# Store the length of the input to track new tokens
input_length = encoding["input_ids"].shape[1]
# Initialize generated text container
generated_text = ""
        # Generate one token at a time so the displayed text can be updated as output arrives
        generated_ids = []
        # Stop on <eos> or on Gemma's end-of-turn marker so generation does not run past the reply
        stop_token_ids = {tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")}
for _ in range(max_new_tokens):
with torch.no_grad():
if len(generated_ids) == 0:
# First token generation
outputs = model.generate(
**encoding,
max_new_tokens=1,
do_sample=True,
temperature=temperature,
pad_token_id=tokenizer.eos_token_id,
return_dict_in_generate=True,
output_scores=False
)
next_token_id = outputs.sequences[0, input_length:input_length+1]
else:
# Subsequent tokens
current_input_ids = torch.cat([encoding["input_ids"], torch.tensor([generated_ids], device=encoding["input_ids"].device)], dim=1)
outputs = model.generate(
input_ids=current_input_ids,
max_new_tokens=1,
do_sample=True,
temperature=temperature,
pad_token_id=tokenizer.eos_token_id,
return_dict_in_generate=True,
output_scores=False
)
next_token_id = outputs.sequences[0, -1].unsqueeze(0)
# Convert to Python list and append
next_token_id_list = next_token_id.tolist()
generated_ids.extend(next_token_id_list)
            # Stop once an end-of-sequence or end-of-turn token has been produced
            if any(token_id in stop_token_ids for token_id in next_token_id_list):
                break
# Decode the tokens generated so far and update the displayed text
current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
generated_text = current_text
response_text.markdown(generated_text)
return generated_text
except Exception as e:
st.session_state.error_message = f"Error during generation: {str(e)}"
st.error(f"Error during generation: {str(e)}")
return None
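# Note: the loop above re-runs model.generate() for every new token, which is simple but
# quadratic in the output length. A leaner pattern is transformers' TextIteratorStreamer,
# which runs a single generate() call in a background thread and yields decoded text as it
# is produced. The function below is an untested sketch of that approach (it is not wired
# into the UI) and assumes the same module-level `tokenizer` and `model`.
def generate_text_streaming_with_streamer(prompt, max_new_tokens=300, temperature=0.7):
    from threading import Thread
    from transformers import TextIteratorStreamer

    formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
    encoding = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    # skip_prompt=True keeps the echoed prompt out of the streamed output
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **encoding,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
    )
    Thread(target=model.generate, kwargs=generation_kwargs).start()

    # Accumulate decoded chunks and refresh a single placeholder as they arrive
    response_text = st.empty()
    streamed = ""
    for chunk in streamer:
        streamed += chunk
        response_text.markdown(streamed)
    return streamed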
# Show any existing error
if st.session_state.error_message:
st.error(f"Error: {st.session_state.error_message}")
# Add troubleshooting information
with st.expander("Troubleshooting Information"):
st.markdown("""
### Common Issues:
1. **Missing Hugging Face Token**: The Gemma model requires authentication. Add your token as a secret named 'HF_TOKEN' in the Space settings.
2. **License Acceptance**: You need to accept the model license on the [Gemma model page](https://huggingface.co/google/gemma-2-2b-it).
3. **Internet Connection**: The model needs to be downloaded the first time the app runs. Ensure your Space has internet access.
4. **Resource Constraints**: The Gemma model requires significant resources. Consider upgrading your Space's hardware if you're encountering memory issues.
### How to Fix:
1. Create a [Hugging Face account](https://huggingface.co/join)
2. Visit the [Gemma model page](https://huggingface.co/google/gemma-2-2b-it) and accept the license
3. Create a token at https://huggingface.co/settings/tokens
4. Add your token to the Space: Settings → Secrets → New Secret (HF_TOKEN)
""")
# Add a debug section
with st.expander("Debug Information"):
st.write(f"Model loaded: {model is not None}")
st.write(f"Tokenizer loaded: {tokenizer is not None}")
st.write(f"Device: {model.device if model else 'N/A'}")
st.write(f"Hugging Face token set: {huggingface_token is not None}")
if torch.cuda.is_available():
st.write(f"CUDA available: True (Device count: {torch.cuda.device_count()})")
else:
st.write("CUDA available: False")
# Generate button
if st.button("Generate Text"):
# Reset any previous errors
st.session_state.error_message = None
if not huggingface_token:
st.error("Hugging Face token is required! Please add your token as described above.")
elif user_input:
st.session_state.user_prompt = user_input
result = generate_text_streaming(user_input, max_length, temperature)
if result is not None: # Only set if no error occurred
st.session_state.generated_text = result
st.session_state.generation_complete = True
else:
st.error("Please enter a prompt first!")
# Analysis section (only show after generation is complete)
if st.session_state.generation_complete and not st.session_state.error_message and st.session_state.generated_text:
# Analysis section
with st.expander("Text Analysis"):
col1, col2 = st.columns(2)
with col1:
st.metric("Character Count", len(st.session_state.generated_text))
st.metric("Word Count", len(st.session_state.generated_text.split()))
with col2:
st.metric("Sentence Count", st.session_state.generated_text.count('.') +
st.session_state.generated_text.count('!') +
st.session_state.generated_text.count('?'))
st.metric("Paragraph Count", st.session_state.generated_text.count('\n\n') + 1)
# Footer
st.markdown("---")
st.markdown("""
<div style="text-align: center">
<p>Created with ❤️ | Powered by Gemma 2-2B-IT and Hugging Face</p>
<p>Code available on <a href="https://huggingface.co/spaces" target="_blank">Hugging Face Spaces</a></p>
</div>
""", unsafe_allow_html=True)