BryanBradfo committed
Commit f871f1a · 1 Parent(s): 1457295

first draft of streamlit app

Files changed (2):
  1. app.py +210 -2
  2. requirements.txt +5 -0
app.py CHANGED
@@ -1,4 +1,212 @@
  import streamlit as st
-
- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import time
+ import os
+ from dotenv import load_dotenv
+
+ # Load environment variables
+ load_dotenv()
+
+ # Set page configuration
+ st.set_page_config(
+     page_title="GemmaTextAppeal",
+     page_icon="✨",
+     layout="wide",
+ )
+
+ # App title and description
+ st.title("✨ GemmaTextAppeal")
+ st.markdown("""
+ ### Interactive Demo of Google's Gemma 2-2B-IT Model
+ This app demonstrates the text generation capabilities of Google's Gemma 2-2B-IT model.
+ Enter a prompt below and see the model generate text in real-time!
+ """)
+
+ # Sidebar with information
+ with st.sidebar:
+     st.header("About Gemma")
+     st.markdown("""
+ [Gemma 2-2B-IT](https://huggingface.co/google/gemma-2-2b-it) is a lightweight 2B-parameter instruction-tuned model from Google's Gemma family.
+
+ Key features:
+ - Efficient text generation
+ - Strong instruction following
+ - 2 billion parameters - fast enough to run on consumer hardware
+ - Trained on a mixture of text and code
+
+ This demo runs directly on Hugging Face Spaces!
+ """)
+
+     st.header("Usage Tips")
+     st.markdown("""
+ - Be specific in your prompts
+ - You can ask for creative content, summaries, or answers to questions
+ - The model performs best when given clear instructions
+ - Try different temperatures to vary creativity vs. coherence
+ """)
+
+     st.header("Sample Prompts")
+     sample_prompts = [
+         "Write a short story about a robot discovering emotions",
+         "Explain quantum computing to a 10-year-old",
+         "Create a recipe for vegan chocolate chip cookies",
+         "Write a haiku about artificial intelligence",
+         "Describe the benefits and risks of generative AI"
+     ]
+
+     for i, prompt in enumerate(sample_prompts):
+         if st.button(f"Example {i+1}", key=f"sample_{i}"):
+             st.session_state.user_prompt = prompt
+
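+ # Clicking a sample button stores that prompt in session state; the text
+ # area below reads it back as its default value on the next rerun.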
+ # Initialize session state variables
+ if 'user_prompt' not in st.session_state:
+     st.session_state.user_prompt = ""
+ if 'generation_complete' not in st.session_state:
+     st.session_state.generation_complete = False
+ if 'generated_text' not in st.session_state:
+     st.session_state.generated_text = ""
+
+ # Model parameters
+ col1, col2 = st.columns(2)
+ with col1:
+     max_length = st.slider("Maximum Length", min_value=50, max_value=1000, value=300, step=50,
+                            help="Maximum number of tokens to generate")
+ with col2:
+     temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.7, step=0.1,
+                             help="Higher values make output more random, lower values more deterministic")
+
+ # User input
+ user_input = st.text_area("Enter your prompt:",
+                           value=st.session_state.user_prompt,
+                           height=100,
+                           placeholder="e.g., Write a short story about a robot discovering emotions")
+
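+ # st.cache_resource keeps a single tokenizer/model instance alive across
+ # Streamlit reruns, so the weights below are downloaded and loaded only once.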
+ # Load the tokenizer and model
+ @st.cache_resource
+ def load_model():
+     # Get the API token; Gemma is a gated model, so a token is normally required
+     huggingface_token = os.getenv("HF_TOKEN")
+     if not huggingface_token:
+         st.warning("No Hugging Face API token found. Some models may not be accessible.")
+
+     tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", token=huggingface_token)
+     model = AutoModelForCausalLM.from_pretrained(
+         "google/gemma-2-2b-it",
+         token=huggingface_token,
+         torch_dtype=torch.float16,
+         device_map="auto"
+     )
+     return tokenizer, model
+
+ def generate_text(prompt, max_new_tokens=300, temperature=0.7):
+     tokenizer, model = load_model()
+
+     # Format the prompt with Gemma's chat turn markers; the tokenizer adds
+     # <bos> on its own, so it is left out here to avoid a doubled BOS token
+     formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
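+     # Note: a more robust alternative is to let the tokenizer build this
+     # string from the model's own chat template, e.g.:
+     #   messages = [{"role": "user", "content": prompt}]
+     #   formatted_prompt = tokenizer.apply_chat_template(
+     #       messages, tokenize=False, add_generation_prompt=True)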
+
+     inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
+
+     # Placeholders for the progress bar, status line, and streamed output
+     progress_bar = st.progress(0)
+     status_text = st.empty()
+     output_area = st.empty()
+
+     # Accumulates the decoded text as it streams in
+     streamer_output = ""
+
+     # Shared sampling arguments; max_new_tokens is fixed at 1 per call so the
+     # progress bar and output can be updated between tokens
+     generate_kwargs = dict(
+         max_new_tokens=1,
+         temperature=temperature,
+         do_sample=True,
+         pad_token_id=tokenizer.eos_token_id
+     )
+
+     status_text.text("Generating response...")
+
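+     # Design note: calling generate() once per token re-encodes the whole
+     # prompt on every step, which is simple but quadratic in sequence length.
+     # A faster pattern is transformers' TextIteratorStreamer, roughly:
+     #   from threading import Thread
+     #   from transformers import TextIteratorStreamer
+     #   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+     #   Thread(target=model.generate,
+     #          kwargs=dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens)).start()
+     #   for chunk in streamer:
+     #       ...  # append chunk and update the UI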
+     with torch.no_grad():
+         # Generate one token at a time so the UI can update between steps
+         for i in range(max_new_tokens):
+             if i == 0:
+                 outputs = model.generate(inputs["input_ids"], **generate_kwargs)
+                 # Keep only the newly generated ids (1-D tensor)
+                 generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
+             else:
+                 # Re-feed the prompt plus everything generated so far
+                 input_ids = torch.cat([inputs["input_ids"], generated_ids.unsqueeze(0)], dim=1)
+                 outputs = model.generate(input_ids, **generate_kwargs)
+                 new_token = outputs[0][-1].unsqueeze(0)
+                 generated_ids = torch.cat([generated_ids, new_token], dim=0)
+
+             # Decode the tokens generated so far and update the streamed view
+             streamer_output = tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+             # Update progress and output
+             progress = min(1.0, (i + 1) / max_new_tokens)
+             progress_bar.progress(progress)
+             output_area.markdown(f"**Generated Response:**\n\n{streamer_output}")
+
+             # Stop once the model emits an end-of-sequence token
+             if generated_ids[-1].item() == tokenizer.eos_token_id:
+                 break
+
+             # Small delay so the streaming is visible
+             time.sleep(0.01)
+
+     status_text.text("Generation complete!")
+     progress_bar.progress(1.0)
+
+     return streamer_output
+
+ # Generate button
+ if st.button("Generate Text"):
+     if user_input:
+         st.session_state.user_prompt = user_input
+         with st.spinner("Generating text..."):
+             st.session_state.generated_text = generate_text(user_input, max_length, temperature)
+         st.session_state.generation_complete = True
+     else:
+         st.error("Please enter a prompt first!")
+
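+ # Streamlit reruns the script from the top on every interaction; storing the
+ # result in session_state lets the display section below survive those reruns.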
+ # Display results
+ if st.session_state.generation_complete:
+     st.markdown("### Generated Text")
+     st.markdown(st.session_state.generated_text)
+
+     # Analysis section
+     with st.expander("Text Analysis"):
+         col1, col2 = st.columns(2)
+         with col1:
+             st.metric("Character Count", len(st.session_state.generated_text))
+             st.metric("Word Count", len(st.session_state.generated_text.split()))
+         with col2:
+             st.metric("Sentence Count", st.session_state.generated_text.count('.') +
+                       st.session_state.generated_text.count('!') +
+                       st.session_state.generated_text.count('?'))
+             st.metric("Paragraph Count", st.session_state.generated_text.count('\n\n') + 1)
+
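+ # Note: the sentence and paragraph counts above are rough punctuation-based
+ # heuristics, not proper linguistic segmentation.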
+ # Footer
+ st.markdown("---")
+ st.markdown("""
+ <div style="text-align: center">
+     <p>Created with ❤️ | Powered by Gemma 2-2B-IT and Hugging Face</p>
+     <p>Code available on <a href="https://huggingface.co/spaces/your-username/GemmaTextAppeal">Hugging Face Spaces</a></p>
+ </div>
+ """, unsafe_allow_html=True)
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ streamlit==1.24.0
+ torch>=2.0.0
+ transformers>=4.42.0
+ python-dotenv==1.0.0
+ accelerate>=0.20.0
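+ # Note: Gemma 2 models need transformers >= 4.42 to load, hence the pin above.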