Luigi committed
Commit d181b45 · 1 Parent(s): 5dd835a

Apply ZeroGPU

Files changed (2)
  1. app.py +71 -130
  2. requirements.txt +3 -2
app.py CHANGED
@@ -1,13 +1,12 @@
 import os
 import time
-import re
 import gc
 import threading
 from itertools import islice
 from datetime import datetime
 import gradio as gr
-from llama_cpp import Llama
-from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
 from huggingface_hub import hf_hub_download
 from duckduckgo_search import DDGS
 
@@ -17,126 +16,77 @@ from duckduckgo_search import DDGS
 cancel_event = threading.Event()
 
 # ------------------------------
-# Model Definitions and Global Variables
 # ------------------------------
-REQUIRED_SPACE_BYTES = 5 * 1024 ** 3  # 5 GB
-
 MODELS = {
     "Taiwan-tinyllama-v1.0-chat (Q8_0)": {
-        "repo_id": "NapYang/DavidLanz-Taiwan-tinyllama-v1.0-chat.GGUF",
-        "filename": "Taiwan-tinyllama-v1.0-chat-Q8_0.gguf",
-        "description": "Taiwan-tinyllama-v1.0-chat (Q8_0)"
     },
     "Llama-3.2-Taiwan-3B-Instruct (Q4_K_M)": {
-        "repo_id": "itlwas/Llama-3.2-Taiwan-3B-Instruct-Q4_K_M-GGUF",
-        "filename": "llama-3.2-taiwan-3b-instruct-q4_k_m.gguf",
-        "description": "Llama-3.2-Taiwan-3B-Instruct (Q4_K_M)"
     },
     "MiniCPM3-4B (Q4_K_M)": {
-        "repo_id": "openbmb/MiniCPM3-4B-GGUF",
-        "filename": "minicpm3-4b-q4_k_m.gguf",
-        "description": "MiniCPM3-4B (Q4_K_M)"
     },
     "Qwen2.5-3B-Instruct (Q4_K_M)": {
-        "repo_id": "Qwen/Qwen2.5-3B-Instruct-GGUF",
-        "filename": "qwen2.5-3b-instruct-q4_k_m.gguf",
-        "description": "Qwen2.5-3B-Instruct (Q4_K_M)"
     },
     "Qwen2.5-7B-Instruct (Q2_K)": {
-        "repo_id": "Qwen/Qwen2.5-7B-Instruct-GGUF",
-        "filename": "qwen2.5-7b-instruct-q2_k.gguf",
-        "description": "Qwen2.5-7B Instruct (Q2_K)"
     },
     "Gemma-3-4B-IT (Q4_K_M)": {
-        "repo_id": "unsloth/gemma-3-4b-it-GGUF",
-        "filename": "gemma-3-4b-it-Q4_K_M.gguf",
-        "description": "Gemma 3 4B IT (Q4_K_M)"
     },
     "Phi-4-mini-Instruct (Q4_K_M)": {
-        "repo_id": "unsloth/Phi-4-mini-instruct-GGUF",
-        "filename": "Phi-4-mini-instruct-Q4_K_M.gguf",
-        "description": "Phi-4 Mini Instruct (Q4_K_M)"
     },
     "Meta-Llama-3.1-8B-Instruct (Q2_K)": {
-        "repo_id": "MaziyarPanahi/Meta-Llama-3.1-8B-Instruct-GGUF",
-        "filename": "Meta-Llama-3.1-8B-Instruct.Q2_K.gguf",
-        "description": "Meta-Llama-3.1-8B-Instruct (Q2_K)"
     },
     "DeepSeek-R1-Distill-Llama-8B (Q2_K)": {
-        "repo_id": "unsloth/DeepSeek-R1-Distill-Llama-8B-GGUF",
-        "filename": "DeepSeek-R1-Distill-Llama-8B-Q2_K.gguf",
-        "description": "DeepSeek-R1-Distill-Llama-8B (Q2_K)"
     },
     "Mistral-7B-Instruct-v0.3 (IQ3_XS)": {
-        "repo_id": "MaziyarPanahi/Mistral-7B-Instruct-v0.3-GGUF",
-        "filename": "Mistral-7B-Instruct-v0.3.IQ3_XS.gguf",
-        "description": "Mistral-7B-Instruct-v0.3 (IQ3_XS)"
     },
     "Qwen2.5-Coder-7B-Instruct (Q2_K)": {
-        "repo_id": "Qwen/Qwen2.5-Coder-7B-Instruct-GGUF",
-        "filename": "qwen2.5-coder-7b-instruct-q2_k.gguf",
-        "description": "Qwen2.5-Coder-7B-Instruct (Q2_K)"
     },
 }
 
 LOADED_MODELS = {}
 CURRENT_MODEL_NAME = None
 
 # ------------------------------
-# Model Loading Helper Functions
 # ------------------------------
-def try_load_model(model_path):
-    try:
-        return Llama(
-            model_path=model_path,
-            n_ctx=4096,
-            n_threads=2,
-            n_threads_batch=1,
-            n_batch=256,
-            n_gpu_layers=0,
-            use_mlock=True,
-            use_mmap=True,
-            verbose=False,
-            logits_all=True,
-            draft_model=LlamaPromptLookupDecoding(num_pred_tokens=2),
-        )
-    except Exception as e:
-        return str(e)
-
-def download_model(selected_model):
-    hf_hub_download(
-        repo_id=selected_model["repo_id"],
-        filename=selected_model["filename"],
-        local_dir="./models",
-        local_dir_use_symlinks=False,
-    )
-
-def validate_or_download_model(selected_model):
-    model_path = os.path.join("models", selected_model["filename"])
-    os.makedirs("models", exist_ok=True)
-    if not os.path.exists(model_path):
-        download_model(selected_model)
-    result = try_load_model(model_path)
-    if isinstance(result, str):
-        try:
-            os.remove(model_path)
-        except Exception:
-            pass
-        download_model(selected_model)
-        result = try_load_model(model_path)
-        if isinstance(result, str):
-            raise Exception(f"Model load failed: {result}")
-    return result
-
 def load_model(model_name):
     global LOADED_MODELS, CURRENT_MODEL_NAME
     if model_name in LOADED_MODELS:
         return LOADED_MODELS[model_name]
     selected_model = MODELS[model_name]
-    model = validate_or_download_model(selected_model)
-    LOADED_MODELS[model_name] = model
     CURRENT_MODEL_NAME = model_name
-    return model
 
 # ------------------------------
 # Web Search Context Retrieval Function
@@ -155,18 +105,10 @@ def retrieve_context(query, max_results=6, max_chars_per_result=600):
     return ""
 
 # ------------------------------
-# Chat Response Generation (Streaming) with Cancellation
 # ------------------------------
 def chat_response(user_message, chat_history, system_prompt, enable_search,
                   max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty):
-    """
-    Generator function that:
-    - Uses the chat history (list of dicts) from the Chatbot.
-    - Appends the new user message.
-    - Optionally retrieves web search context.
-    - Streams the assistant response token-by-token.
-    - Checks for cancellation.
-    """
     # Reset the cancellation event.
     cancel_event.clear()
 
@@ -194,7 +136,7 @@ def chat_response(user_message, chat_history, system_prompt, enable_search,
         retrieved_context = ""
         debug_message = "Web search disabled."
 
-    # Augment prompt.
     if enable_search and retrieved_context:
         augmented_user_input = (
             f"{system_prompt.strip()}\n\n"
@@ -205,41 +147,44 @@ def chat_response(user_message, chat_history, system_prompt, enable_search,
     else:
         augmented_user_input = f"{system_prompt.strip()}\n\nUser Query: {user_message}"
 
-    # Build final prompt messages.
-    messages = internal_history[:-1] + [{"role": "user", "content": augmented_user_input}]
-
-    # Load the model.
-    model = load_model(model_name)
-
-    # Add an empty assistant message.
     internal_history.append({"role": "assistant", "content": ""})
-    assistant_message = ""
 
     try:
-        stream = model.create_chat_completion(
-            messages=messages,
-            max_tokens=max_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            repeat_penalty=repeat_penalty,
-            stream=True,
-        )
-        for chunk in stream:
-            # Check if a cancellation has been requested.
             if cancel_event.is_set():
                 assistant_message += "\n\n[Response generation cancelled by user]"
                 internal_history[-1]["content"] = assistant_message
                 yield internal_history, debug_message
-                break
-
-            if "choices" in chunk:
-                delta = chunk["choices"][0]["delta"].get("content", "")
-                assistant_message += delta
-                internal_history[-1]["content"] = assistant_message
-                yield internal_history, debug_message
-                if chunk["choices"][0].get("finish_reason", ""):
-                    break
     except Exception as e:
         internal_history[-1]["content"] = f"Error: {e}"
         yield internal_history, debug_message
@@ -255,8 +200,8 @@ def cancel_generation():
 # ------------------------------
 # Gradio UI Definition
 # ------------------------------
-with gr.Blocks(title="Multi-GGUF LLM Inference") as demo:
-    gr.Markdown("## 🧠 Multi-GGUF LLM Inference with Web Search")
     gr.Markdown("Interact with the model. Select your model, set your system prompt, and adjust parameters on the left.")
 
     with gr.Row():
@@ -303,18 +248,14 @@ with gr.Blocks(title="Multi-GGUF LLM Inference") as demo:
         return [], "", ""
 
     clear_button.click(fn=clear_chat, outputs=[chatbot, msg_input, search_debug])
-
     cancel_button.click(fn=cancel_generation, outputs=search_debug)
 
-    # Submission that returns conversation and debug info.
     msg_input.submit(
         fn=chat_response,
         inputs=[msg_input, chatbot, system_prompt_text, enable_search_checkbox,
                 max_results_number, max_chars_number, model_dropdown,
                 max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repeat_penalty_slider],
         outputs=[chatbot, search_debug],
-        # Uncomment streaming=True if supported.
-        # streaming=True,
     )
 
 demo.launch()
 
 import os
 import time
 import gc
 import threading
 from itertools import islice
 from datetime import datetime
 import gradio as gr
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from huggingface_hub import hf_hub_download
 from duckduckgo_search import DDGS
 
 cancel_event = threading.Event()
 
 # ------------------------------
+# Model Definitions and Global Variables (PyTorch/Transformers)
+# ------------------------------
+# Here, the repo_id should point to a model checkpoint that is compatible with Hugging Face Transformers.
+# ------------------------------
+# Torch-Compatible Model Definitions with Adjusted Descriptions
 # ------------------------------
 MODELS = {
     "Taiwan-tinyllama-v1.0-chat (Q8_0)": {
+        "repo_id": "DavidLanz/Taiwan-tinyllama-v1.0-chat",
+        "description": "Taiwan-tinyllama-v1.0-chat (Q8_0) – Torch-compatible version converted from GGUF."
     },
     "Llama-3.2-Taiwan-3B-Instruct (Q4_K_M)": {
+        "repo_id": "https://huggingface.co/lianghsun/Llama-3.2-Taiwan-3B-Instruct",
+        "description": "Llama-3.2-Taiwan-3B-Instruct (Q4_K_M) – Torch-compatible version converted from GGUF."
     },
     "MiniCPM3-4B (Q4_K_M)": {
+        "repo_id": "openbmb/MiniCPM3-4B",
+        "description": "MiniCPM3-4B (Q4_K_M) – Torch-compatible version converted from GGUF."
     },
     "Qwen2.5-3B-Instruct (Q4_K_M)": {
+        "repo_id": "Qwen/Qwen2.5-3B-Instruct",
+        "description": "Qwen2.5-3B-Instruct (Q4_K_M) – Torch-compatible version converted from GGUF."
     },
     "Qwen2.5-7B-Instruct (Q2_K)": {
+        "repo_id": "Qwen/Qwen2.5-7B-Instruct",
+        "description": "Qwen2.5-7B-Instruct (Q2_K) – Torch-compatible version converted from GGUF."
     },
     "Gemma-3-4B-IT (Q4_K_M)": {
+        "repo_id": "unsloth/gemma-3-4b-it",
+        "description": "Gemma-3-4B-IT (Q4_K_M) – Torch-compatible version converted from GGUF."
     },
     "Phi-4-mini-Instruct (Q4_K_M)": {
+        "repo_id": "unsloth/Phi-4-mini-instruct",
+        "description": "Phi-4-mini-Instruct (Q4_K_M) – Torch-compatible version converted from GGUF."
     },
     "Meta-Llama-3.1-8B-Instruct (Q2_K)": {
+        "repo_id": "MaziyarPanahi/Meta-Llama-3.1-8B-Instruct",
+        "description": "Meta-Llama-3.1-8B-Instruct (Q2_K) – Torch-compatible version converted from GGUF."
     },
     "DeepSeek-R1-Distill-Llama-8B (Q2_K)": {
+        "repo_id": "unsloth/DeepSeek-R1-Distill-Llama-8B",
+        "description": "DeepSeek-R1-Distill-Llama-8B (Q2_K) – Torch-compatible version converted from GGUF."
     },
     "Mistral-7B-Instruct-v0.3 (IQ3_XS)": {
+        "repo_id": "MaziyarPanahi/Mistral-7B-Instruct-v0.3",
+        "description": "Mistral-7B-Instruct-v0.3 (IQ3_XS) – Torch-compatible version converted from GGUF."
     },
     "Qwen2.5-Coder-7B-Instruct (Q2_K)": {
+        "repo_id": "Qwen/Qwen2.5-Coder-7B-Instruct",
+        "description": "Qwen2.5-Coder-7B-Instruct (Q2_K) – Torch-compatible version converted from GGUF."
     },
 }
 
+
 LOADED_MODELS = {}
 CURRENT_MODEL_NAME = None
 
 # ------------------------------
+# Model Loading Helper Function (PyTorch/Transformers)
 # ------------------------------
 def load_model(model_name):
     global LOADED_MODELS, CURRENT_MODEL_NAME
     if model_name in LOADED_MODELS:
         return LOADED_MODELS[model_name]
     selected_model = MODELS[model_name]
+    # Load both the model and tokenizer using the Transformers library.
+    model = AutoModelForCausalLM.from_pretrained(selected_model["repo_id"], trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(selected_model["repo_id"], trust_remote_code=True)
+    LOADED_MODELS[model_name] = (model, tokenizer)
     CURRENT_MODEL_NAME = model_name
+    return model, tokenizer
 
 # ------------------------------
 # Web Search Context Retrieval Function
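The hunk above only swaps model loading over to Transformers; nothing in this diff shows the ZeroGPU allocation itself, which on Hugging Face ZeroGPU Spaces is normally requested with the `spaces.GPU` decorator from the preinstalled `spaces` package. Below is a minimal sketch of that pattern, assuming a ZeroGPU Space; the `generate_reply` helper, its parameters, and the placement of `.to("cuda")` are illustrative assumptions, not code from this commit.

```python
# Sketch only: how a ZeroGPU Space typically requests a GPU slice.
# `spaces` is preinstalled on ZeroGPU Spaces; `generate_reply` is a hypothetical helper.
import spaces
import torch

@spaces.GPU(duration=120)  # the GPU is attached only while this function runs
def generate_reply(model, tokenizer, prompt, max_new_tokens=256):
    model = model.to("cuda")  # moving weights at load time is also common
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```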
 
     return ""
 
 # ------------------------------
+# Chat Response Generation (Simulated Streaming) with Cancellation
 # ------------------------------
 def chat_response(user_message, chat_history, system_prompt, enable_search,
                   max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty):
     # Reset the cancellation event.
     cancel_event.clear()
 
         retrieved_context = ""
         debug_message = "Web search disabled."
 
+    # Augment prompt with search context if available.
     if enable_search and retrieved_context:
         augmented_user_input = (
             f"{system_prompt.strip()}\n\n"
 
     else:
         augmented_user_input = f"{system_prompt.strip()}\n\nUser Query: {user_message}"
 
+    # Append a placeholder for the assistant's response.
     internal_history.append({"role": "assistant", "content": ""})
 
     try:
+        # Load the PyTorch model and tokenizer.
+        model, tokenizer = load_model(model_name)
+
+        # Tokenize the input prompt.
+        input_ids = tokenizer(augmented_user_input, return_tensors="pt").input_ids
+        with torch.no_grad():
+            output_ids = model.generate(
+                input_ids,
+                max_new_tokens=max_tokens,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repeat_penalty,
+                do_sample=True
+            )
+
+        # Decode the generated tokens.
+        generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+        # Strip the original prompt to isolate the assistant’s reply.
+        assistant_text = generated_text[len(augmented_user_input):].strip()
+
+        # Simulate streaming by yielding the output word by word.
+        words = assistant_text.split()
+        assistant_message = ""
+        for word in words:
             if cancel_event.is_set():
                 assistant_message += "\n\n[Response generation cancelled by user]"
                 internal_history[-1]["content"] = assistant_message
                 yield internal_history, debug_message
+                return
+            assistant_message += word + " "
+            internal_history[-1]["content"] = assistant_message
+            yield internal_history, debug_message
+            time.sleep(0.05)  # Short delay to simulate streaming
     except Exception as e:
         internal_history[-1]["content"] = f"Error: {e}"
         yield internal_history, debug_message
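The hunk above decodes the complete generation first and then replays it word by word with `time.sleep` to simulate streaming. If true token-level streaming were wanted, transformers' `TextIteratorStreamer` can emit text while `generate` runs in a background thread; the sketch below is under that assumption and is not part of this commit (the `stream_reply` name is hypothetical, and `model`/`tokenizer` are assumed to come from `load_model`).

```python
# Sketch only: true token streaming with TextIteratorStreamer
# (this commit instead replays the finished text word by word).
from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(model, tokenizer, prompt, max_new_tokens=256):
    inputs = tokenizer(prompt, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True)
    Thread(target=model.generate, kwargs=kwargs, daemon=True).start()
    partial = ""
    for piece in streamer:   # yields decoded text as tokens arrive
        partial += piece
        yield partial        # suitable for a Gradio streaming callback
```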
 
 # ------------------------------
 # Gradio UI Definition
 # ------------------------------
+with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
+    gr.Markdown("## 🧠 ZeroGPU LLM Inference with Web Search")
     gr.Markdown("Interact with the model. Select your model, set your system prompt, and adjust parameters on the left.")
 
     with gr.Row():
 
         return [], "", ""
 
     clear_button.click(fn=clear_chat, outputs=[chatbot, msg_input, search_debug])
     cancel_button.click(fn=cancel_generation, outputs=search_debug)
 
     msg_input.submit(
         fn=chat_response,
         inputs=[msg_input, chatbot, system_prompt_text, enable_search_checkbox,
                 max_results_number, max_chars_number, model_dropdown,
                 max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repeat_penalty_slider],
         outputs=[chatbot, search_debug],
     )
 
 demo.launch()
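Because `chat_response` is a generator, its yielded updates reach the browser through Gradio's queue; recent Gradio releases enable queueing by default, while older ones need an explicit call. A minimal self-contained sketch of that streaming pattern is shown below; the diff itself keeps plain `demo.launch()`, and all component names here are illustrative.

```python
# Sketch only: a generator callback streaming through Gradio's queue.
# Recent Gradio versions queue by default; the explicit call matters on older ones.
import time
import gradio as gr

def echo_slowly(text):
    shown = ""
    for ch in text:
        shown += ch
        time.sleep(0.02)
        yield shown          # each yield is pushed to the browser

with gr.Blocks() as demo:
    box = gr.Textbox(label="Type and press Enter")
    out = gr.Textbox(label="Streamed echo")
    box.submit(fn=echo_slowly, inputs=box, outputs=out)

demo.queue(max_size=16)      # buffer concurrent requests
demo.launch()
```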
requirements.txt CHANGED
@@ -5,7 +5,8 @@
 wheel
 jieba
 docopt
-llama-cpp-python --no-binary=:all: --global-option=build_ext --global-option="--cmake-args=-DGGML_CUDA=on"
 streamlit
 duckduckgo_search
-gradio
 
 wheel
 jieba
 docopt
 streamlit
 duckduckgo_search
+gradio
+torch
+transformers