davidizzle committed on
Commit 8020e39 · 1 Parent(s): d8f007f

Model upgrade and GPU support

Files changed (2):
1. README.md +2 -1
2. app.py +12 -5
README.md CHANGED

````diff
@@ -23,4 +23,5 @@ An interactive [Streamlit](https://streamlit.io) app to test [DeepSeek](https://
 
 ```bash
 pip install -r requirements.txt
-streamlit run app.py
+streamlit run app.py
+```
````
app.py CHANGED

````diff
@@ -40,14 +40,17 @@ def load_model():
     # As Gemma is gated, we will show functionality of the demo using DeepSeek-R1-Distill-Qwen-1.5B model
     # model_id = "google/gemma-2b-it"
     # tokenizer = AutoTokenizer.from_pretrained(model_id, token=True)
-    model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+    # model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+    model_id = "deepseek-ai/deepseek-llm-7b-chat"
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        device_map=None,
-        torch_dtype=torch.float32
+        # device_map=None,
+        # torch_dtype=torch.float32
+        device_map="auto",
+        torch_dtype=torch.float16
     )
-    model.to("cpu")
+    # model.to("cpu")
     return tokenizer, model
 
 tokenizer, model = load_model()
@@ -95,7 +98,11 @@ if st.button("Generate"):
 
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     with torch.no_grad():
-        outputs = model.generate(**inputs, max_new_tokens=100, temperature=1.0, top_p=0.95)
+        outputs = model.generate(**inputs,
+                                 # max_new_tokens=100,
+                                 max_new_tokens=200,
+                                 temperature=1.0,
+                                 top_p=0.95)
 
     # Back to still
     # gif_html.markdown(
````
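
In sum, the commit swaps the CPU-only DeepSeek-R1-Distill-Qwen-1.5B (fp32, pinned to `"cpu"`) for `deepseek-ai/deepseek-llm-7b-chat` loaded with `device_map="auto"` and `torch.float16`, and doubles the generation budget from 100 to 200 new tokens. A minimal sketch of the resulting load/generate path, assuming the standard transformers API; the `torch.cuda.is_available()` fallback branch and the example prompt are illustrative additions, not part of the commit:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model(model_id: str = "deepseek-ai/deepseek-llm-7b-chat"):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if torch.cuda.is_available():
        # GPU path introduced by this commit: automatic device placement, fp16 weights
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.float16,
        )
    else:
        # CPU fallback mirroring the pre-commit configuration (assumption, not in the diff)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=None,
            torch_dtype=torch.float32,
        )
    return tokenizer, model

tokenizer, model = load_model()
inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)
with torch.no_grad():
    # Settings from the commit; note that temperature/top_p only take effect
    # when sampling is enabled (do_sample=True) in recent transformers versions.
    outputs = model.generate(**inputs, max_new_tokens=200, temperature=1.0, top_p=0.95)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

`device_map="auto"` delegates weight placement to accelerate, which can also shard the fp16 weights across multiple GPUs if one is not enough; that, together with halving memory via fp16, is what makes the jump from a 1.5B to a 7B model practical here.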