import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
# Set page configuration
st.set_page_config(
    page_title="GemmaTextAppeal",
    page_icon="✨",
    layout="wide",
)
# App title and description
st.title("✨ GemmaTextAppeal")
st.markdown("""
### Interactive Demo of Google's Gemma 2-2B-IT Model

This app demonstrates the text generation capabilities of Google's Gemma 2-2B-IT model.
Enter a prompt below and watch the model generate text in real time!
""")
# Function to load the tokenizer and model.
# st.cache_resource keeps the loaded model in memory across Streamlit reruns,
# so it is downloaded and initialized only once instead of on every interaction.
@st.cache_resource
def load_model():
    try:
        # Get the Hugging Face API token
        huggingface_token = os.getenv("HF_TOKEN")
        if not huggingface_token:
            return None, None, "No Hugging Face API token found. Please add your token as a secret named 'HF_TOKEN'."

        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            "google/gemma-2-2b-it",
            token=huggingface_token
        )

        # Load the model with an appropriate configuration:
        # float16 with automatic device placement on GPU, float32 on CPU
        model_kwargs = {
            "token": huggingface_token,
            "device_map": "auto" if torch.cuda.is_available() else None,
            "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
        }
        model = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2-2b-it",
            **model_kwargs
        )
        return tokenizer, model, None
    except Exception as e:
        return None, None, str(e)
# Try to load the model at startup
with st.spinner("Initializing the Gemma model... this may take a minute."):
    tokenizer, model, load_error = load_model()

if load_error:
    st.error(f"Error loading model: {load_error}")
else:
    if tokenizer and model:
        st.success("✅ Gemma model loaded successfully! Ready to generate text.")
    else:
        st.warning("⚠️ Model not loaded. Please check your Hugging Face token.")
# Check for Hugging Face token
huggingface_token = os.getenv("HF_TOKEN")
if not huggingface_token:
    st.warning("""
⚠️ **No Hugging Face API token detected**

The Gemma models require accepting a license and authentication to use.

To make this app work:
1. Create a Hugging Face account
2. Accept the model license at: https://huggingface.co/google/gemma-2-2b-it
3. Create a Hugging Face token at: https://huggingface.co/settings/tokens
4. Add your token as a secret named 'HF_TOKEN' in your Space settings
""")
# Sidebar with information
with st.sidebar:
    st.header("About Gemma")
    st.markdown("""
[Gemma 2-2B-IT](https://huggingface.co/google/gemma-2-2b-it) is a lightweight 2B-parameter instruction-tuned model from Google's Gemma family.

Key features:
- Efficient text generation
- Strong instruction following
- 2 billion parameters - fast enough to run on consumer hardware
- Trained on a mixture of text and code

This demo runs directly on Hugging Face Spaces!
""")

    st.header("Usage Tips")
    st.markdown("""
- Be specific in your prompts
- You can ask for creative content, summaries, or answers to questions
- The model performs best when given clear instructions
- Try different temperatures to vary creativity vs. coherence
""")

    st.header("Sample Prompts")
    sample_prompts = [
        "Write a short story about a robot discovering emotions",
        "Explain quantum computing to a 10-year-old",
        "Create a recipe for vegan chocolate chip cookies",
        "Write a haiku about artificial intelligence",
        "Describe the benefits and risks of generative AI"
    ]
    for i, prompt in enumerate(sample_prompts):
        if st.button(f"Example {i+1}", key=f"sample_{i}"):
            st.session_state.user_prompt = prompt
# Initialize session state variables
if 'user_prompt' not in st.session_state:
    st.session_state.user_prompt = ""
if 'generation_complete' not in st.session_state:
    st.session_state.generation_complete = False
if 'generated_text' not in st.session_state:
    st.session_state.generated_text = ""
if 'error_message' not in st.session_state:
    st.session_state.error_message = None
# Model parameters
col1, col2 = st.columns(2)
with col1:
    max_length = st.slider("Maximum Length", min_value=50, max_value=1000, value=300, step=50,
                           help="Maximum number of tokens to generate")
with col2:
    temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.7, step=0.1,
                            help="Higher values make output more random, lower values more deterministic")

# User input
user_input = st.text_area("Enter your prompt:",
                          value=st.session_state.user_prompt,
                          height=100,
                          placeholder="e.g., Write a short story about a robot discovering emotions")
def generate_text_streaming(prompt, max_new_tokens=300, temperature=0.7):
    if not tokenizer or not model:
        st.session_state.error_message = "Model not properly loaded. Please check your Hugging Face token."
        return None
    try:
        # Format the prompt according to Gemma's chat format.
        # The tokenizer prepends <bos> itself, so it is not included here.
        formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

        # Create the output area
        response_area = st.container()
        with response_area:
            st.markdown("**Generated Response:**")
            response_text = st.empty()

        # Tokenize the input
        encoding = tokenizer(formatted_prompt, return_tensors="pt")

        # Move inputs to the GPU if one is available
        if torch.cuda.is_available():
            encoding = {k: v.to("cuda") for k, v in encoding.items()}

        # Store the length of the input to track new tokens
        input_length = encoding["input_ids"].shape[1]

        # Tokens and text generated so far
        generated_ids = []
        generated_text = ""

        # Gemma's instruction-tuned checkpoints typically close a reply with
        # <end_of_turn>, so treat it as a stop token in addition to EOS.
        stop_token_ids = {tokenizer.eos_token_id,
                          tokenizer.convert_tokens_to_ids("<end_of_turn>")}

        # Generate one token at a time so the UI can update after every step.
        # This re-runs generate() per token (no KV-cache reuse), which is simple
        # but slow; see the TextIteratorStreamer sketch below for a faster approach.
        for _ in range(max_new_tokens):
            with torch.no_grad():
                if len(generated_ids) == 0:
                    # First token: generate directly from the prompt
                    outputs = model.generate(
                        **encoding,
                        max_new_tokens=1,
                        do_sample=True,
                        temperature=temperature,
                        pad_token_id=tokenizer.eos_token_id,
                        return_dict_in_generate=True,
                        output_scores=False
                    )
                    next_token_id = outputs.sequences[0, input_length:input_length+1]
                else:
                    # Subsequent tokens: append everything generated so far to the prompt
                    current_input_ids = torch.cat(
                        [encoding["input_ids"],
                         torch.tensor([generated_ids], device=encoding["input_ids"].device)],
                        dim=1
                    )
                    outputs = model.generate(
                        input_ids=current_input_ids,
                        max_new_tokens=1,
                        do_sample=True,
                        temperature=temperature,
                        pad_token_id=tokenizer.eos_token_id,
                        return_dict_in_generate=True,
                        output_scores=False
                    )
                    next_token_id = outputs.sequences[0, -1].unsqueeze(0)

            # Convert to a Python list and append
            next_token_id_list = next_token_id.tolist()
            generated_ids.extend(next_token_id_list)

            # Stop when the model emits an end-of-sequence or end-of-turn token
            if any(t in stop_token_ids for t in next_token_id_list):
                break

            # Decode the tokens generated so far and update the displayed text
            generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
            response_text.markdown(generated_text)

        return generated_text
    except Exception as e:
        st.session_state.error_message = f"Error during generation: {str(e)}"
        st.error(f"Error during generation: {str(e)}")
        return None
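
# Alternative streaming sketch (not wired into the UI above): recent versions of
# transformers ship a TextIteratorStreamer that yields decoded text as generate()
# produces it, reusing the KV cache instead of re-running generate() once per token.
# This is a minimal illustration under that assumption; generate_text_streaming()
# above remains the function the app actually calls.
def generate_with_streamer(prompt, max_new_tokens=300, temperature=0.7):
    from threading import Thread
    from transformers import TextIteratorStreamer

    formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    # skip_prompt drops the input tokens; skip_special_tokens hides <end_of_turn> etc.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() blocks, so run it in a background thread and consume the streamer here
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    placeholder = st.empty()
    text = ""
    for chunk in streamer:
        text += chunk
        placeholder.markdown(text)
    return text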
# Show any existing error
if st.session_state.error_message:
    st.error(f"Error: {st.session_state.error_message}")

    # Add troubleshooting information
    with st.expander("Troubleshooting Information"):
        st.markdown("""
### Common Issues:
1. **Missing Hugging Face Token**: The Gemma model requires authentication. Add your token as a secret named 'HF_TOKEN' in the Space settings.
2. **License Acceptance**: You need to accept the model license on the [Gemma model page](https://huggingface.co/google/gemma-2-2b-it).
3. **Internet Connection**: The model needs to be downloaded the first time the app runs. Ensure your Space has internet access.
4. **Resource Constraints**: The Gemma model requires significant resources. Consider upgrading your Space's hardware if you're encountering memory issues.

### How to Fix:
1. Create a [Hugging Face account](https://huggingface.co/join)
2. Visit the [Gemma model page](https://huggingface.co/google/gemma-2-2b-it) and accept the license
3. Create a token at https://huggingface.co/settings/tokens
4. Add your token to the Space: Settings → Secrets → New Secret (HF_TOKEN)
""")
# Add a debug section
with st.expander("Debug Information"):
    st.write(f"Model loaded: {model is not None}")
    st.write(f"Tokenizer loaded: {tokenizer is not None}")
    st.write(f"Device: {model.device if model else 'N/A'}")
    st.write(f"Hugging Face token set: {huggingface_token is not None}")
    if torch.cuda.is_available():
        st.write(f"CUDA available: True (Device count: {torch.cuda.device_count()})")
    else:
        st.write("CUDA available: False")
# Generate button
if st.button("Generate Text"):
    # Reset any previous errors
    st.session_state.error_message = None

    if not huggingface_token:
        st.error("Hugging Face token is required! Please add your token as described above.")
    elif user_input:
        st.session_state.user_prompt = user_input
        result = generate_text_streaming(user_input, max_length, temperature)
        if result is not None:  # Only store the result if no error occurred
            st.session_state.generated_text = result
            st.session_state.generation_complete = True
    else:
        st.error("Please enter a prompt first!")
# Analysis section (only shown after generation is complete)
if st.session_state.generation_complete and not st.session_state.error_message and st.session_state.generated_text:
    with st.expander("Text Analysis"):
        col1, col2 = st.columns(2)
        with col1:
            st.metric("Character Count", len(st.session_state.generated_text))
            st.metric("Word Count", len(st.session_state.generated_text.split()))
        with col2:
            st.metric("Sentence Count", st.session_state.generated_text.count('.') +
                      st.session_state.generated_text.count('!') +
                      st.session_state.generated_text.count('?'))
            st.metric("Paragraph Count", st.session_state.generated_text.count('\n\n') + 1)
# Footer
st.markdown("---")
st.markdown("""
<div style="text-align: center">
  <p>Created with ❤️ | Powered by Gemma 2-2B-IT and Hugging Face</p>
  <p>Code available on <a href="https://huggingface.co/spaces" target="_blank">Hugging Face Spaces</a></p>
</div>
""", unsafe_allow_html=True)