BryanBradfo committed on
Commit 104c3a4 · 1 Parent(s): 0574f0a

change code + readme

Files changed (3)
  1. README.md +2 -2
  2. app.py +22 -12
  3. requirements.txt +6 -5
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
 title: GemmaTextAppeal
-emoji: 👁
+emoji: 🩵
 colorFrom: purple
-colorTo: gray
+colorTo: blue
 sdk: streamlit
 sdk_version: 1.44.1
 app_file: app.py
app.py CHANGED
@@ -38,12 +38,25 @@ def load_model():
             token=huggingface_token
         )
 
+        # Load model - use CPU configuration if no GPU available
+        model_kwargs = {
+            "token": huggingface_token,
+            "torch_dtype": torch.float16
+        }
+
+        # Only add device_map if GPU is available
+        if torch.cuda.is_available():
+            model_kwargs["device_map"] = "auto"
+
         model = AutoModelForCausalLM.from_pretrained(
             "google/gemma-2-2b-it",
-            token=huggingface_token,
-            torch_dtype=torch.float16,
-            device_map="auto"
+            **model_kwargs
         )
+
+        # Move model to CPU if no GPU
+        if not torch.cuda.is_available():
+            model = model.to("cpu")
+
         return tokenizer, model, None
     except Exception as e:
         return None, None, str(e)
@@ -168,8 +181,6 @@ def generate_text(prompt, max_new_tokens=300, temperature=0.7):
             pad_token_id=tokenizer.eos_token_id
         )
 
-        st.write("Generation completed, processing output...")
-
         # Get only the generated part (exclude the prompt)
         new_tokens = output_ids[0][input_ids.shape[1]:]
         generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
@@ -225,7 +236,7 @@ if st.session_state.error_message:
     with st.expander("Debug Information"):
         st.write(f"Model loaded: {model is not None}")
         st.write(f"Tokenizer loaded: {tokenizer is not None}")
-        st.write(f"Device mapping: {model.device_map if model else 'N/A'}")
+        st.write(f"Device: {model.device if model else 'N/A'}")
         st.write(f"Hugging Face token set: {huggingface_token is not None}")
         if torch.cuda.is_available():
             st.write(f"CUDA available: True (Device count: {torch.cuda.device_count()})")
@@ -241,12 +252,11 @@ if st.button("Generate Text"):
         st.error("Hugging Face token is required! Please add your token as described above.")
     elif user_input:
         st.session_state.user_prompt = user_input
-        st.write("Starting text generation...")
-        result = generate_text(user_input, max_length, temperature)
-        st.write(f"Generation result: {'Success' if result else 'Failed'}")
-        if result is not None:  # Only set if no error occurred
-            st.session_state.generated_text = result
-            st.session_state.generation_complete = True
+        with st.spinner("Generating text..."):
+            result = generate_text(user_input, max_length, temperature)
+            if result is not None:  # Only set if no error occurred
+                st.session_state.generated_text = result
+                st.session_state.generation_complete = True
     else:
         st.error("Please enter a prompt first!")
 
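Taken together, the first hunk reshapes model loading so the Space can also start on CPU-only hardware. A minimal standalone sketch of the resulting load_model is below; the @st.cache_resource decorator, the imports, and reading huggingface_token from the environment are assumptions, since they sit outside this diff.

import os

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: the token comes from the environment (e.g. a Space secret).
huggingface_token = os.getenv("HF_TOKEN")

@st.cache_resource  # assumption: loading is cached across Streamlit reruns
def load_model():
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            "google/gemma-2-2b-it",
            token=huggingface_token,
        )

        # Base kwargs; device_map="auto" is only added when a GPU exists,
        # since accelerate-style dispatch adds nothing on a bare CPU box.
        model_kwargs = {"token": huggingface_token, "torch_dtype": torch.float16}
        if torch.cuda.is_available():
            model_kwargs["device_map"] = "auto"

        model = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2-2b-it",
            **model_kwargs,
        )

        # Without device_map the weights land on CPU anyway; the explicit
        # move mirrors the commit and makes the intent obvious.
        if not torch.cuda.is_available():
            model = model.to("cpu")

        return tokenizer, model, None
    except Exception as e:
        return None, None, str(e)

One caveat: keeping torch.float16 on CPU can be slow and some ops lack half-precision CPU kernels, so falling back to torch.float32 when no GPU is available is a common variation on this pattern.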
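The last hunk swaps the st.write debug messages for a spinner around generation. A small self-contained sketch of that Streamlit pattern, with a placeholder generate_text and assumed input widgets standing in for the app's real ones, might look like this.

import time

import streamlit as st

def generate_text(prompt, max_new_tokens=300, temperature=0.7):
    # Placeholder for the app's generate_text(); returns None on failure.
    time.sleep(1)  # stand-in for model inference
    return f"(demo) continuation of: {prompt}"

user_input = st.text_area("Enter a prompt")               # assumed widget
max_length = st.slider("Max new tokens", 50, 500, 300)    # assumed control
temperature = st.slider("Temperature", 0.1, 1.5, 0.7)     # assumed control

if st.button("Generate Text"):
    if user_input:
        st.session_state.user_prompt = user_input
        # The spinner gives feedback without cluttering the page with debug writes.
        with st.spinner("Generating text..."):
            result = generate_text(user_input, max_length, temperature)
            if result is not None:  # Only set if no error occurred
                st.session_state.generated_text = result
                st.session_state.generation_complete = True
    else:
        st.error("Please enter a prompt first!")

if st.session_state.get("generation_complete"):
    st.write(st.session_state.generated_text)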
requirements.txt CHANGED
@@ -1,6 +1,7 @@
-
-streamlit==1.24.0
-torch>=2.0.0
-transformers>=4.34.0
+streamlit==1.30.0
+torch==2.1.0
+transformers==4.35.0
 python-dotenv==1.0.0
-accelerate>=0.20.0
+huggingface-hub==0.19.0
+accelerate==0.23.0
+protobuf==3.20.3
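The dependencies are now fully pinned, with huggingface-hub and protobuf added explicitly. To confirm a running Space actually picked up these versions, a quick sanity check from Python is sketched below; the expected values are simply the pins above.

import importlib.metadata as md

# Compare installed versions against the pins in requirements.txt.
for pkg in ["streamlit", "torch", "transformers", "python-dotenv",
            "huggingface-hub", "accelerate", "protobuf"]:
    print(pkg, md.version(pkg))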