Commit 104c3a4 (1 parent: 0574f0a)

change code + readme

Files changed:
- README.md (+2, -2)
- app.py (+22, -12)
- requirements.txt (+6, -5)
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
 title: GemmaTextAppeal
-emoji:
+emoji: 🩵
 colorFrom: purple
-colorTo:
+colorTo: blue
 sdk: streamlit
 sdk_version: 1.44.1
 app_file: app.py
app.py CHANGED
@@ -38,12 +38,25 @@ def load_model():
             token=huggingface_token
         )
 
+        # Load model - use CPU configuration if no GPU available
+        model_kwargs = {
+            "token": huggingface_token,
+            "torch_dtype": torch.float16
+        }
+
+        # Only add device_map if GPU is available
+        if torch.cuda.is_available():
+            model_kwargs["device_map"] = "auto"
+
         model = AutoModelForCausalLM.from_pretrained(
             "google/gemma-2-2b-it",
-
-            torch_dtype=torch.float16,
-            device_map="auto"
+            **model_kwargs
         )
+
+        # Move model to CPU if no GPU
+        if not torch.cuda.is_available():
+            model = model.to("cpu")
+
         return tokenizer, model, None
     except Exception as e:
         return None, None, str(e)
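For context, a minimal sketch of the loading pattern this hunk introduces: the from_pretrained keyword arguments are collected in a dict, device_map="auto" is requested only when CUDA is available, and the model is moved to the CPU explicitly otherwise. The st.cache_resource decorator and the source of huggingface_token are assumptions about the surrounding app.py; neither is visible in this commit.

# Sketch only; the decorator and the token source are assumptions, not shown in the diff.
import os
import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

huggingface_token = os.getenv("HF_TOKEN")  # assumed token source

@st.cache_resource
def load_model():
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            "google/gemma-2-2b-it", token=huggingface_token
        )

        # Collect kwargs once; request accelerate's device_map only on GPU.
        model_kwargs = {"token": huggingface_token, "torch_dtype": torch.float16}
        if torch.cuda.is_available():
            model_kwargs["device_map"] = "auto"

        model = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2-2b-it", **model_kwargs
        )
        if not torch.cuda.is_available():
            model = model.to("cpu")  # explicit CPU placement when no GPU

        return tokenizer, model, None
    except Exception as e:
        return None, None, str(e)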
@@ -168,8 +181,6 @@ def generate_text(prompt, max_new_tokens=300, temperature=0.7):
             pad_token_id=tokenizer.eos_token_id
         )
 
-        st.write("Generation completed, processing output...")
-
         # Get only the generated part (exclude the prompt)
         new_tokens = output_ids[0][input_ids.shape[1]:]
         generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
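The lines kept by this hunk rely on model.generate returning the prompt tokens followed by the newly generated ones, so slicing at input_ids.shape[1] leaves only the completion. A small usage sketch, assuming tokenizer and model come from load_model() above; the prompt string is only illustrative.

# Decode only the tokens generated after the prompt.
inputs = tokenizer("Write a short poem about rain.", return_tensors="pt").to(model.device)
input_ids = inputs["input_ids"]                            # shape: (1, prompt_len)
output_ids = model.generate(input_ids, max_new_tokens=50)  # prompt tokens + new tokens
new_tokens = output_ids[0][input_ids.shape[1]:]            # drop the echoed prompt
print(tokenizer.decode(new_tokens, skip_special_tokens=True))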
@@ -225,7 +236,7 @@ if st.session_state.error_message:
     with st.expander("Debug Information"):
         st.write(f"Model loaded: {model is not None}")
         st.write(f"Tokenizer loaded: {tokenizer is not None}")
-        st.write(f"Device
+        st.write(f"Device: {model.device if model else 'N/A'}")
         st.write(f"Hugging Face token set: {huggingface_token is not None}")
         if torch.cuda.is_available():
             st.write(f"CUDA available: True (Device count: {torch.cuda.device_count()})")
@@ -241,12 +252,11 @@ if st.button("Generate Text"):
         st.error("Hugging Face token is required! Please add your token as described above.")
     elif user_input:
         st.session_state.user_prompt = user_input
-        st.
-
-
-
-
-        st.session_state.generation_complete = True
+        with st.spinner("Generating text..."):
+            result = generate_text(user_input, max_length, temperature)
+            if result is not None:  # Only set if no error occurred
+                st.session_state.generated_text = result
+                st.session_state.generation_complete = True
     else:
         st.error("Please enter a prompt first!")
 
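The final hunk wires the Generate button to the generator: the call now runs inside st.spinner, and results are written to st.session_state only when generate_text returns something. A reduced sketch of that flow; user_input, max_length, and temperature are assumed to be Streamlit widgets defined earlier in app.py, and generate_text is assumed to return None on error, as the guard implies.

# Reduced sketch of the button handler; the widget variables are assumptions.
if st.button("Generate Text"):
    if not huggingface_token:
        st.error("Hugging Face token is required! Please add your token as described above.")
    elif user_input:
        st.session_state.user_prompt = user_input
        with st.spinner("Generating text..."):
            result = generate_text(user_input, max_length, temperature)
            if result is not None:  # only store output when generation succeeded
                st.session_state.generated_text = result
                st.session_state.generation_complete = True
    else:
        st.error("Please enter a prompt first!")

# Elsewhere in the app (not part of this commit) the stored result can then be rendered, e.g.:
if st.session_state.get("generation_complete"):
    st.markdown(st.session_state.generated_text)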
requirements.txt CHANGED
@@ -1,6 +1,7 @@
-
-
-
-transformers>=4.34.0
+streamlit==1.30.0
+torch==2.1.0
+transformers==4.35.0
 python-dotenv==1.0.0
-
+huggingface-hub==0.19.0
+accelerate==0.23.0
+protobuf==3.20.3