import os
import time
import gc
import threading
from itertools import islice
from datetime import datetime
import gradio as gr
import torch
from transformers import pipeline, TextIteratorStreamer
from duckduckgo_search import DDGS
import spaces  # Import spaces early to enable ZeroGPU support

# Optional: Disable GPU visibility if you wish to force CPU usage
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
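# Assumed runtime dependencies for this Space (not pinned in this file):
# gradio, torch, transformers, accelerate, duckduckgo_search, and the
# Hugging Face `spaces` package that provides ZeroGPU support.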

# ------------------------------
# Global Cancellation Event
# ------------------------------
cancel_event = threading.Event()

# ------------------------------
# Torch-Compatible Model Definitions with Adjusted Descriptions
# ------------------------------
MODELS = {
    "Gemma-3-4B-IT": {"repo_id": "unsloth/gemma-3-4b-it", "description": "Gemma-3-4B-IT"},
    "SmolLM2-135M-Instruct-TaiwanChat": {"repo_id": "Luigi/SmolLM2-135M-Instruct-TaiwanChat", "description": "SmolLM2-135M Instruct fine-tuned on TaiwanChat"},
    "SmolLM2-135M-Instruct": {"repo_id": "HuggingFaceTB/SmolLM2-135M-Instruct", "description": "Original SmolLM2-135M Instruct"},
    "Llama-3.2-Taiwan-3B-Instruct": {"repo_id": "lianghsun/Llama-3.2-Taiwan-3B-Instruct", "description": "Llama-3.2-Taiwan-3B-Instruct"},
    "MiniCPM3-4B": {"repo_id": "openbmb/MiniCPM3-4B", "description": "MiniCPM3-4B"},
    "Qwen2.5-3B-Instruct": {"repo_id": "Qwen/Qwen2.5-3B-Instruct", "description": "Qwen2.5-3B-Instruct"},
    "Qwen2.5-7B-Instruct": {"repo_id": "Qwen/Qwen2.5-7B-Instruct", "description": "Qwen2.5-7B-Instruct"},
    "Phi-4-mini-Instruct": {"repo_id": "unsloth/Phi-4-mini-instruct", "description": "Phi-4-mini-Instruct"},
    "Meta-Llama-3.1-8B-Instruct": {"repo_id": "MaziyarPanahi/Meta-Llama-3.1-8B-Instruct", "description": "Meta-Llama-3.1-8B-Instruct"},
    "DeepSeek-R1-Distill-Llama-8B": {"repo_id": "unsloth/DeepSeek-R1-Distill-Llama-8B", "description": "DeepSeek-R1-Distill-Llama-8B"},
    "Mistral-7B-Instruct-v0.3": {"repo_id": "MaziyarPanahi/Mistral-7B-Instruct-v0.3", "description": "Mistral-7B-Instruct-v0.3"},
    "Qwen2.5-Coder-7B-Instruct": {"repo_id": "Qwen/Qwen2.5-Coder-7B-Instruct", "description": "Qwen2.5-Coder-7B-Instruct"},
}

# Global cache for pipelines to avoid re-loading.
PIPELINES = {}

def load_pipeline(model_name):
    """
    Load and cache a transformers pipeline for text generation.
    Tries bfloat16, falls back to float16 or float32 if unsupported.
    """
    global PIPELINES
    if model_name in PIPELINES:
        return PIPELINES[model_name]
    repo = MODELS[model_name]["repo_id"]
    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(
                task="text-generation",
                model=repo,
                tokenizer=repo,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto"
            )
            PIPELINES[model_name] = pipe
            return pipe
        except Exception:
            continue
    # Final fallback
    pipe = pipeline(
        task="text-generation",
        model=repo,
        tokenizer=repo,
        trust_remote_code=True,
        device_map="auto"
    )
    PIPELINES[model_name] = pipe
    return pipe
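
# Note: device_map="auto" relies on the `accelerate` package; if it is missing,
# pipeline creation raises, the dtype loop above falls through, and the final
# fallback fails for the same reason.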

def retrieve_context(query, max_results=6, max_chars=600):
    """
    Retrieve search snippets from DuckDuckGo (runs in background).
    Returns a list of result strings.
    """
    try:
        with DDGS() as ddgs:
            return [
                f"{i+1}. {r.get('title', 'No Title')} - {r.get('body', '')[:max_chars]}"
                for i, r in enumerate(
                    islice(ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y"), max_results)
                )
            ]
    except Exception:
        return []
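
# Each snippet above is a plain string; an illustrative example (hypothetical
# result title) would look like:
# "3. Example Title - the first max_chars characters of the result body..."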

def format_conversation(history, system_prompt):
    """
    Flatten chat history and system prompt into a single string.
    """
    prompt = system_prompt.strip() + "\n"
    for msg in history:
        if msg['role'] == 'user':
            prompt += "User: " + msg['content'].strip() + "\n"
        elif msg['role'] == 'assistant':
            prompt += "Assistant: " + msg['content'].strip() + "\n"
        else:
            prompt += msg['content'].strip() + "\n"
    if not prompt.strip().endswith("Assistant:"):
        prompt += "Assistant: "
    return prompt

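# On ZeroGPU Spaces the GPU is attached only while a function decorated with
# @spaces.GPU is running; without it, generation falls back to CPU. The
# 60-second duration here is an assumed per-request budget, not a value taken
# from the original code.
@spaces.GPU(duration=60)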
def chat_response(user_msg, chat_history, system_prompt,
                  enable_search, max_results, max_chars,
                  model_name, max_tokens, temperature,
                  top_k, top_p, repeat_penalty):
    """
    Generates streaming chat responses, optionally with background web search.
    """
    cancel_event.clear()
    history = list(chat_history or [])
    history.append({'role': 'user', 'content': user_msg})
    # Launch web search if enabled
    debug = ''
    search_results = []
    if enable_search:
        debug = 'Search task started.'
        thread_search = threading.Thread(
            target=lambda: search_results.extend(
                retrieve_context(user_msg, int(max_results), int(max_chars))
            )
        )
        thread_search.daemon = True
        thread_search.start()
    else:
        debug = 'Web search disabled.'
    # Prepare assistant placeholder
    history.append({'role': 'assistant', 'content': ''})
    try:
        # Wait briefly for search snippets, then surface them in the debug panel
        if enable_search:
            thread_search.join(timeout=1.0)
            if search_results:
                debug = "### Search results merged into prompt\n\n" + "\n".join(
                    f"- {r}" for r in search_results
                )
            else:
                debug = "*No web search results found.*"
        # Merge any fetched snippets into the system prompt
        if search_results:
            enriched = system_prompt.strip() + "\n\nRelevant context:\n" + "\n".join(search_results)
        else:
            enriched = system_prompt
        prompt = format_conversation(history, enriched)
        pipe = load_pipeline(model_name)
        streamer = TextIteratorStreamer(pipe.tokenizer,
                                        skip_prompt=True,
                                        skip_special_tokens=True)
        gen_thread = threading.Thread(
            target=pipe,
            args=(prompt,),
            kwargs={
                'max_new_tokens': max_tokens,
                'temperature': temperature,
                'top_k': top_k,
                'top_p': top_p,
                'repetition_penalty': repeat_penalty,
                'streamer': streamer,
                'return_full_text': False
            }
        )
        gen_thread.start()
        assistant_text = ''
        for chunk in streamer:
            if cancel_event.is_set():
                break
            assistant_text += chunk
            history[-1]['content'] = assistant_text
            # Stream the partial assistant reply along with the search debug info
            yield history, debug
        gen_thread.join()
    except Exception as e:
        history[-1]['content'] = f"Error: {e}"
        yield history, debug
    finally:
        gc.collect()

def cancel_generation():
    cancel_event.set()
    return 'Generation cancelled.'


def update_default_prompt(enable_search):
    # The default prompt is the same whether or not web search is enabled; the
    # argument only exists so this can be wired to the checkbox's change event.
    today = datetime.now().strftime('%Y-%m-%d')
    return f"You are a helpful assistant. Today is {today}."

# ------------------------------
# Gradio UI
# ------------------------------
with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
    gr.Markdown("## 🧠 ZeroGPU LLM Inference with Web Search")
    gr.Markdown("Interact with the model. Select parameters and chat below.")
    with gr.Row():
        with gr.Column(scale=3):
            model_dd = gr.Dropdown(label="Select Model", choices=list(MODELS.keys()), value=list(MODELS.keys())[0])
            search_chk = gr.Checkbox(label="Enable Web Search", value=True)
            sys_prompt = gr.Textbox(label="System Prompt", lines=3, value=update_default_prompt(search_chk.value))
            gr.Markdown("### Generation Parameters")
            max_tok = gr.Slider(64, 1024, value=512, step=32, label="Max Tokens")
            temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            k = gr.Slider(1, 100, value=40, step=1, label="Top-K")
            p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
            rp = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
            gr.Markdown("### Web Search Settings")
            mr = gr.Number(value=6, precision=0, label="Max Results")
            mc = gr.Number(value=600, precision=0, label="Max Chars/Result")
            clr = gr.Button("Clear Chat")
            cnl = gr.Button("Cancel Generation")
        with gr.Column(scale=7):
            chat = gr.Chatbot(type="messages")
            txt = gr.Textbox(placeholder="Type your message and press Enter...")
            dbg = gr.Markdown()
    search_chk.change(fn=update_default_prompt, inputs=search_chk, outputs=sys_prompt)
    clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])
    cnl.click(fn=cancel_generation, outputs=dbg)
    txt.submit(fn=chat_response,
               inputs=[txt, chat, sys_prompt, search_chk, mr, mc,
                       model_dd, max_tok, temp, k, p, rp],
               outputs=[chat, dbg])

demo.launch()