Commit f871f1a
Parent(s): 1457295
first draft of streamlit app

- app.py  +210 -2
- requirements.txt  +5 -0
app.py
CHANGED
@@ -1,4 +1,212 @@
 import streamlit as st
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import time
+import os
+from dotenv import load_dotenv
 
+# Load environment variables
+load_dotenv()
+
+# Set page configuration
+st.set_page_config(
+    page_title="GemmaTextAppeal",
+    page_icon="✨",
+    layout="wide",
+)
+
+# App title and description
+st.title("✨ GemmaTextAppeal")
+st.markdown("""
+### Interactive Demo of Google's Gemma 2-2B-IT Model
+This app demonstrates the text generation capabilities of Google's Gemma 2-2B-IT model.
+Enter a prompt below and see the model generate text in real-time!
+""")
+
+# Sidebar with information
+with st.sidebar:
+    st.header("About Gemma")
+    st.markdown("""
+[Gemma 2-2B-IT](https://huggingface.co/google/gemma-2-2b-it) is a lightweight 2B parameter instruction-tuned model from Google's Gemma family.
+
+Key features:
+- Efficient text generation
+- Strong instruction following
+- 2 billion parameters - fast enough to run on consumer hardware
+- Trained on a mixture of text and code
+
+This demo runs directly on Hugging Face Spaces!
+""")
+
+    st.header("Usage Tips")
+    st.markdown("""
+- Be specific in your prompts
+- You can ask for creative content, summaries, or answers to questions
+- The model performs best when given clear instructions
+- Try different temperatures to vary creativity vs. coherence
+""")
+
+    st.header("Sample Prompts")
+    sample_prompts = [
+        "Write a short story about a robot discovering emotions",
+        "Explain quantum computing to a 10-year old",
+        "Create a recipe for vegan chocolate chip cookies",
+        "Write a haiku about artificial intelligence",
+        "Describe the benefits and risks of generative AI"
+    ]
+
+    for i, prompt in enumerate(sample_prompts):
+        if st.button(f"Example {i+1}", key=f"sample_{i}"):
+            st.session_state.user_prompt = prompt
+
+# Initialize session state variables
+if 'user_prompt' not in st.session_state:
+    st.session_state.user_prompt = ""
+if 'generation_complete' not in st.session_state:
+    st.session_state.generation_complete = False
+if 'generated_text' not in st.session_state:
+    st.session_state.generated_text = ""
+
+# Model parameters
+col1, col2 = st.columns(2)
+with col1:
+    max_length = st.slider("Maximum Length", min_value=50, max_value=1000, value=300, step=50,
+                           help="Maximum number of tokens to generate")
+with col2:
+    temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.7, step=0.1,
+                            help="Higher values make output more random, lower values more deterministic")
+
+# User input
+user_input = st.text_area("Enter your prompt:",
+                          value=st.session_state.user_prompt,
+                          height=100,
+                          placeholder="e.g., Write a short story about a robot discovering emotions")
+
+# Function to load model and generate text
+@st.cache_resource
+def load_model():
+    # Get API Token
+    huggingface_token = os.getenv("HF_TOKEN")
+    if not huggingface_token:
+        st.warning("No Hugging Face API token found. Some models may not be accessible.")
+
+    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", token=huggingface_token)
+    model = AutoModelForCausalLM.from_pretrained(
+        "google/gemma-2-2b-it",
+        token=huggingface_token,
+        torch_dtype=torch.float16,
+        device_map="auto"
+    )
+    return tokenizer, model
+
+def generate_text(prompt, max_new_tokens=300, temperature=0.7):
+    tokenizer, model = load_model()
+
+    # Format the prompt according to Gemma's expected format
+    formatted_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
+
+    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
+
+    # Create the progress bar
+    progress_bar = st.progress(0)
+    status_text = st.empty()
+    output_area = st.empty()
+
+    tokens_generated = 0
+    generated_text = ""
+
+    # Generate with streaming
+    streamer_output = ""
+
+    # Generate with step-by-step tracking for the progress bar
+    generate_kwargs = dict(
+        inputs=inputs["input_ids"],
+        # max_new_tokens is set per call below (one token per step)
+        temperature=temperature,
+        do_sample=True,
+        pad_token_id=tokenizer.eos_token_id
+    )
+
+    status_text.text("Generating response...")
+
+    with torch.no_grad():
+        # Generate text step by step
+        for i in range(max_new_tokens):
+            if i == 0:
+                outputs = model.generate(
+                    **generate_kwargs,
+                    max_new_tokens=1,
+                )
+                generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
+            else:
+                input_ids = torch.cat([inputs["input_ids"], generated_ids.unsqueeze(0)], dim=1)
+                outputs = model.generate(
+                    input_ids=input_ids,
+                    max_new_tokens=1,
+                    do_sample=True,
+                    temperature=temperature,
+                    pad_token_id=tokenizer.eos_token_id
+                )
+                new_token = outputs[0][-1].unsqueeze(0)
+                generated_ids = torch.cat([generated_ids, new_token], dim=0)
+
+            # Decode text
+            current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+            # Update streaming output
+            streamer_output = current_text
+
+            # Update progress and output
+            progress = min(1.0, (i + 1) / max_new_tokens)
+            progress_bar.progress(progress)
+
+            # Update display
+            output_area.markdown(f"**Generated Response:**\n\n{streamer_output}")
+
+            # Check if we've reached an end token
+            if generated_ids[-1].item() == tokenizer.eos_token_id:
+                break
+
+            # Add a small delay to simulate typing
+            time.sleep(0.01)
+
+    status_text.text("Generation complete!")
+    progress_bar.progress(1.0)
+
+    return streamer_output
+
+# Generate button
+if st.button("Generate Text"):
+    if user_input:
+        st.session_state.user_prompt = user_input
+        with st.spinner("Generating text..."):
+            st.session_state.generated_text = generate_text(user_input, max_length, temperature)
+            st.session_state.generation_complete = True
+    else:
+        st.error("Please enter a prompt first!")
+
+# Display results
+if st.session_state.generation_complete:
+    st.markdown("### Generated Text")
+    st.markdown(st.session_state.generated_text)
+
+    # Analysis section
+    with st.expander("Text Analysis"):
+        col1, col2 = st.columns(2)
+        with col1:
+            st.metric("Character Count", len(st.session_state.generated_text))
+            st.metric("Word Count", len(st.session_state.generated_text.split()))
+        with col2:
+            st.metric("Sentence Count", st.session_state.generated_text.count('.') +
+                                        st.session_state.generated_text.count('!') +
+                                        st.session_state.generated_text.count('?'))
+            st.metric("Paragraph Count", st.session_state.generated_text.count('\n\n') + 1)
+
+# Footer
+st.markdown("---")
+st.markdown("""
+<div style="text-align: center">
+<p>Created with ❤️ | Powered by Gemma 2-2B-IT and Hugging Face</p>
+<p>Code available on <a href="https://huggingface.co/spaces/your-username/GemmaTextAppeal">Hugging Face Spaces</a></p>
+</div>
+""", unsafe_allow_html=True)
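A note on generate_text() above: the loop calls model.generate() once per new token, which re-encodes the entire prefix on every step, and the hand-built "<bos><start_of_turn>..." string typically duplicates the BOS token that the tokenizer inserts on its own. A lighter way to keep the live-typing effect is transformers' TextIteratorStreamer combined with tokenizer.apply_chat_template. The sketch below is not part of this commit; it reuses the load_model() defined in app.py, and the name generate_text_streaming is only illustrative.

```python
# Sketch (not in this commit): streaming generation for the same Streamlit app,
# assuming the load_model() helper from app.py above is in scope.
from threading import Thread

import streamlit as st
from transformers import TextIteratorStreamer


def generate_text_streaming(prompt, max_new_tokens=300, temperature=0.7):
    tokenizer, model = load_model()

    # apply_chat_template builds the <start_of_turn> wrapper (and BOS) itself,
    # so the prompt does not need to be formatted by hand.
    messages = [{"role": "user", "content": prompt}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # The streamer yields decoded text pieces as generate() produces them.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )

    # Run generation once in a background thread and update the UI as text arrives.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    output_area = st.empty()
    generated = ""
    for piece in streamer:
        generated += piece
        output_area.markdown(f"**Generated Response:**\n\n{generated}")
    thread.join()
    return generated
```

Because generate() runs only once, the per-token Python overhead and the artificial time.sleep() delay disappear, while the st.empty() placeholder still updates as tokens arrive.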
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+streamlit==1.24.0
+torch>=2.0.0
+transformers>=4.31.0
+python-dotenv==1.0.0
+accelerate>=0.20.0
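requirements.txt includes python-dotenv because app.py reads HF_TOKEN through load_dotenv(); on Spaces the token would instead be configured as a repository secret. A quick local sanity check before deploying might look like the following sketch, assuming a .env file next to app.py with an HF_TOKEN entry (that file should never be committed):

```python
# Local sanity check (not part of the commit): verifies that HF_TOKEN is picked up
# from .env and that the gated Gemma tokenizer can be downloaded.
import os

from dotenv import load_dotenv
from transformers import AutoTokenizer

load_dotenv()  # expects a .env file containing a line like: HF_TOKEN=<your token>
token = os.getenv("HF_TOKEN")
print("HF_TOKEN present:", bool(token))

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", token=token)
print("Tokenizer loaded; vocab size:", tokenizer.vocab_size)
```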