import os
import subprocess

# Install flash-attn from a prebuilt wheel (skip the local CUDA build), then
# make sure transformers is present. os.environ is merged in so the pip
# subprocess keeps PATH and the rest of the environment.
subprocess.run(
    'pip install flash-attn==2.7.0.post2 --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
    shell=True
)
subprocess.run(
    'pip install transformers',
    shell=True
)
import spaces
import re
import logging
from typing import List
from threading import Thread
import base64

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# ----------------------------------------------------------------------
# 1. Setup Model & Tokenizer
# ----------------------------------------------------------------------
model_name = 'smirki/UIGEN-T1.1-Qwen-14B'  # Change as needed

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).to("cuda")
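# Note: flash-attn is installed above but never requested here, so
# transformers will fall back to its default attention implementation.
# A sketch of how one could opt in (assumption: the flash-attn wheel
# matches this torch/CUDA build):
#
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name,
#       torch_dtype=torch.bfloat16,
#       attn_implementation="flash_attention_2",
#       trust_remote_code=True,
#   ).to("cuda")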
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# The attention masks below compare token ids against the pad token, so make
# sure one exists (fall back to EOS, the usual convention for Qwen-style models).
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

logging.getLogger("httpx").setLevel(logging.WARNING)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------
# 2. Two-Phase Prompt Templates
# ----------------------------------------------------------------------
s1_inference_prompt_think_only = """<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
<|im_start|>think
"""
# ----------------------------------------------------------------------
# 3. Generation Parameter Setup
# ----------------------------------------------------------------------
THINK_MAX_NEW_TOKENS = 12000
ANSWER_MAX_NEW_TOKENS = 12000

def initialize_gen_kwargs():
    return {
        "max_new_tokens": 1024,  # default; will be overwritten per phase
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.05,
        # "eos_token_id": model.generation_config.eos_token_id,  # Removed to avoid premature stopping
        "pad_token_id": tokenizer.pad_token_id,
        "use_cache": True,
        "streamer": None  # dynamically added per phase
    }
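# For reference, ovis_chat below runs generation twice with these kwargs:
# phase 1 uses the think-only template above, and phase 2 appends the model's
# own reasoning plus an answer tag, so the final prompt is shaped like this
# (hypothetical content, for illustration only):
#
#   <|im_start|>user
#   Build a pricing page<|im_end|>
#   <|im_start|>assistant
#   <|im_start|>think
#   ...chain-of-thought streamed in phase 1...
#   <|im_start|>answer
#   ...final answer, ideally containing a ```html fenced block...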
# ----------------------------------------------------------------------
# 4. Helper to submit chat
# ----------------------------------------------------------------------
def submit_chat(chatbot, text_input):
    if not text_input.strip():
        return chatbot, ""
    # Append as a list, not a tuple: ovis_chat mutates chatbot[-1][1] in
    # place while streaming, which would fail on an immutable tuple.
    chatbot.append([text_input, ""])
    return chatbot, ""
# ----------------------------------------------------------------------
# 5. Artifacts Handling
#    We parse code from the final answer and display it in an iframe
# ----------------------------------------------------------------------
def extract_html_code_block(text: str) -> str:
    """
    Look for a ```html ... ``` block in the text.
    If found, return only that block's content.
    Otherwise, return the entire text.
    """
    pattern = r'```html\s*(.*?)\s*```'
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1).strip()
    else:
        return text.strip()
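# Example (hypothetical input):
#   extract_html_code_block("Sure!\n```html\n<h1>Hello</h1>\n```")
#   -> "<h1>Hello</h1>"
# Text without a fenced html block is returned unchanged (stripped).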
def send_to_sandbox(html_code: str) -> str:
    """
    Convert the code to a data-URI iframe so it can be rendered
    inside a Gradio HTML component.
    """
    encoded_html = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
    data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
    return f'<iframe src="{data_uri}" width="100%" height="920px"></iframe>'
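# If stricter isolation is wanted, the iframe could carry a sandbox
# attribute (a sketch; trades some iframe capabilities for isolation from
# the host page):
#   f'<iframe src="{data_uri}" sandbox="allow-scripts" width="100%" height="920px"></iframe>'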
# ----------------------------------------------------------------------
# 6. The Two-Phase Streaming Inference
#    - Phase 1: "think" (chain-of-thought)
#    - Phase 2: "answer"
# ----------------------------------------------------------------------
@spaces.GPU(duration=120)  # ZeroGPU requires decorating GPU work; 120s is a guess for long generations
def ovis_chat(chatbot: List[List[str]]):
    # Phase 1: chain-of-thought
    last_query = chatbot[-1][0]
    formatted_think_prompt = s1_inference_prompt_think_only.format(question=last_query)
    input_ids_think = tokenizer.encode(formatted_think_prompt, return_tensors="pt").to(model.device)
    attention_mask_think = torch.ne(input_ids_think, tokenizer.pad_token_id).to(model.device)
    think_inputs = {
        "input_ids": input_ids_think,
        "attention_mask": attention_mask_think
    }
    gen_kwargs_think = initialize_gen_kwargs()
    gen_kwargs_think["max_new_tokens"] = THINK_MAX_NEW_TOKENS
    think_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs_think["streamer"] = think_streamer

    # torch.inference_mode() is thread-local, so it must be entered inside
    # the worker thread; wrapping Thread.start() from the caller has no effect.
    def _generate_think():
        with torch.inference_mode():
            model.generate(**think_inputs, **gen_kwargs_think)

    full_think = ""
    thread_think = Thread(target=_generate_think)
    thread_think.start()
    for new_text in think_streamer:
        full_think += new_text
        display_text = f"<|im_start|>think\n{full_think.strip()}"
        chatbot[-1][1] = display_text
        yield chatbot, ""  # second return value is the artifact placeholder
    thread_think.join()
    # Phase 2: answer
    new_prompt = formatted_think_prompt + full_think.strip() + "\n<|im_start|>answer\n"
    input_ids_answer = tokenizer.encode(new_prompt, return_tensors="pt").to(model.device)
    attention_mask_answer = torch.ne(input_ids_answer, tokenizer.pad_token_id).to(model.device)
    answer_inputs = {
        "input_ids": input_ids_answer,
        "attention_mask": attention_mask_answer
    }
    gen_kwargs_answer = initialize_gen_kwargs()
    gen_kwargs_answer["max_new_tokens"] = ANSWER_MAX_NEW_TOKENS
    answer_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs_answer["streamer"] = answer_streamer

    # Same pattern as phase 1: inference_mode lives inside the worker thread.
    def _generate_answer():
        with torch.inference_mode():
            model.generate(**answer_inputs, **gen_kwargs_answer)

    full_answer = ""
    thread_answer = Thread(target=_generate_answer)
    thread_answer.start()
    for new_text in answer_streamer:
        full_answer += new_text
        display_text = (
            f"<|im_start|>think\n{full_think.strip()}\n\n"
            f"<|im_start|>answer\n{full_answer.strip()}"
        )
        chatbot[-1][1] = display_text
        yield chatbot, ""
    thread_answer.join()

    log_conversation(chatbot)

    # Once the final answer is complete, parse out the HTML code block and
    # return it as an artifact (iframe).
    html_code = extract_html_code_block(full_answer)
    sandbox_iframe = send_to_sandbox(html_code)
    yield chatbot, sandbox_iframe
# ----------------------------------------------------------------------
# 7. Logging and Clearing
# ----------------------------------------------------------------------
def log_conversation(chatbot: List[List[str]]):
    logger.info("[CONVERSATION]")
    for i, (query, response) in enumerate(chatbot, 1):
        logger.info(f"Q{i}: {query}\nA{i}: {response}")

def clear_chat():
    # Resets the chatbot, the prompt textbox, and the artifact panel.
    return [], "", ""
# ----------------------------------------------------------------------
# 8. Gradio UI Setup
# ----------------------------------------------------------------------
css_code = """
.left_header {
  display: flex;
  flex-direction: column;
  justify-content: center;
  align-items: center;
}
.right_panel {
  margin-top: 16px;
  border: 1px solid #BFBFC4;
  border-radius: 8px;
  overflow: hidden;
}
.render_header {
  height: 30px;
  width: 100%;
  padding: 5px 16px;
  background-color: #f5f5f5;
}
.header_btn {
  display: inline-block;
  height: 10px;
  width: 10px;
  border-radius: 50%;
  margin-right: 4px;
}
.render_header > .header_btn:nth-child(1) {
  background-color: #f5222d;
}
.render_header > .header_btn:nth-child(2) {
  background-color: #faad14;
}
.render_header > .header_btn:nth-child(3) {
  background-color: #52c41a;
}
.right_content {
  height: 920px;
  display: flex;
  flex-direction: column;
  justify-content: center;
  align-items: center;
}
.html_content {
  width: 100%;
  height: 920px;
}
"""
svg_content = """
<svg width="40" height="40" viewBox="0 0 45 45" fill="none" xmlns="http://www.w3.org/2000/svg">
  <circle cx="22.5" cy="22.5" r="22.5" fill="#5572F9"/>
  <path d="M22.5 11.25L26.25 16.875H18.75L22.5 11.25Z" fill="white"/>
  <path d="M22.5 33.75L26.25 28.125H18.75L22.5 33.75Z" fill="white"/>
  <path d="M28.125 22.5L22.5 28.125L16.875 22.5L22.5 16.875L28.125 22.5Z" fill="white"/>
</svg>
"""
with gr.Blocks(title=model_name.split('/')[-1], css=css_code) as demo:
    gr.HTML(f"""
    <div class="left_header" style="margin-bottom: 20px;">
        {svg_content}
        <h1>{model_name.split('/')[-1]} - Chat + Artifacts</h1>
        <p>(Two-phase chain-of-thought with artifact extraction)</p>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                label="Chat",
                height=520,
                show_copy_button=True
            )
            with gr.Row():
                text_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your query...",
                    lines=1
                )
            with gr.Row():
                submit_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
        with gr.Column(scale=6):
            gr.HTML('<div class="render_header"><span class="header_btn"></span><span class="header_btn"></span><span class="header_btn"></span></div>')
            artifact_html = gr.HTML(
                value="",
                elem_classes="html_content"
            )
    submit_btn.click(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(
        ovis_chat, [chatbot], [chatbot, artifact_html]
    )
    text_input.submit(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(
        ovis_chat, [chatbot], [chatbot, artifact_html]
    )
    clear_btn.click(
        clear_chat,
        outputs=[chatbot, text_input, artifact_html]
    )

demo.queue(default_concurrency_limit=1).launch(server_name="0.0.0.0", share=True)
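# Local usage (a sketch; assumes this file is saved as app.py and a CUDA GPU
# with enough memory for the 14B model in bf16, roughly 28 GB of VRAM):
#   python app.py
# then open http://localhost:7860. On Spaces itself, share=True is ignored.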