import os
import subprocess

# Install flash-attn from a prebuilt wheel. FLASH_ATTENTION_SKIP_CUDA_BUILD
# skips compiling the CUDA extension at install time; merge os.environ so pip
# still sees PATH and friends (a bare env dict would replace the environment).
subprocess.run(
    'pip install flash-attn==2.7.0.post2 --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True
)
subprocess.run(
    'pip install transformers',
    shell=True
)
import spaces
import re
import logging
import base64
from typing import List
from threading import Thread

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# ----------------------------------------------------------------------
# 1. Setup Model & Tokenizer
# ----------------------------------------------------------------------
model_name = 'smirki/UIGEN-T1.1-Qwen-14B' # Change as needed
use_thread = True # Generation happens in a background thread
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
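
# Some checkpoints ship without an explicit pad token; fall back to EOS so the
# pad_token_id passed to generate() and the attention masks built below stay
# valid. (Defensive guard; Qwen-based models usually define one already.)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id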
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------
# 2. Two-Phase Prompt Templates
# ----------------------------------------------------------------------
s1_inference_prompt_think_only = """<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
<|im_start|>think
"""
# ----------------------------------------------------------------------
# 3. Generation Parameter Setup
# ----------------------------------------------------------------------
THINK_MAX_NEW_TOKENS = 12000
ANSWER_MAX_NEW_TOKENS = 12000
def initialize_gen_kwargs():
    return {
        "max_new_tokens": 1024,  # default; overwritten per phase
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.05,
        # "eos_token_id": model.generation_config.eos_token_id,  # removed to avoid premature stopping
        "pad_token_id": tokenizer.pad_token_id,
        "use_cache": True,
        "streamer": None,  # set dynamically per phase
    }
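
# Each phase copies these defaults and then overrides the token budget and the
# streamer, along the lines of:
#   kwargs = initialize_gen_kwargs()
#   kwargs["max_new_tokens"] = THINK_MAX_NEW_TOKENS
#   kwargs["streamer"] = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)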
# ----------------------------------------------------------------------
# 4. Helper to submit chat
# ----------------------------------------------------------------------
def submit_chat(chatbot, text_input):
    if not text_input.strip():
        return chatbot, ""
    # Append as a list (not a tuple) so ovis_chat can assign chatbot[-1][1] in place.
    chatbot.append([text_input, ""])
    return chatbot, ""
# ----------------------------------------------------------------------
# 5. Artifacts Handling
# We parse code from the final answer and display it in an iframe
# ----------------------------------------------------------------------
def extract_html_code_block(text: str) -> str:
    """
    Look for a ```html ... ``` block in the text.
    If found, return only that block's content.
    Otherwise, return the entire text.
    """
    pattern = r'```html\s*(.*?)\s*```'
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return text.strip()
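
# For example (illustrative inputs):
#   extract_html_code_block("intro\n```html\n<p>hi</p>\n```\nbye")  -> "<p>hi</p>"
#   extract_html_code_block("no fenced block here")                 -> "no fenced block here"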
def send_to_sandbox(html_code: str) -> str:
    """
    Convert the code to a data-URI iframe so it can be rendered
    inside a Gradio HTML component.
    """
    encoded_html = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
    data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
    return f'<iframe src="{data_uri}" width="100%" height="920px"></iframe>'
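
# The base64 data URI gives the iframe an opaque origin, so the generated page
# cannot reach into the host app. Scripts could be locked down further with a
# sandbox attribute, e.g. (optional hardening, not used above):
#   f'<iframe src="{data_uri}" sandbox="allow-scripts" width="100%" height="920px"></iframe>'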
# ----------------------------------------------------------------------
# 6. The Two-Phase Streaming Inference
# - Phase 1: "think" (chain-of-thought)
# - Phase 2: "answer"
# ----------------------------------------------------------------------
@spaces.GPU
def ovis_chat(chatbot: List[List[str]]):
    # Phase 1: chain-of-thought
    last_query = chatbot[-1][0]
    formatted_think_prompt = s1_inference_prompt_think_only.format(question=last_query)
    input_ids_think = tokenizer.encode(formatted_think_prompt, return_tensors="pt").to(model.device)
    attention_mask_think = torch.ne(input_ids_think, tokenizer.pad_token_id).to(model.device)
    think_inputs = {
        "input_ids": input_ids_think,
        "attention_mask": attention_mask_think
    }

    gen_kwargs_think = initialize_gen_kwargs()
    gen_kwargs_think["max_new_tokens"] = THINK_MAX_NEW_TOKENS
    think_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs_think["streamer"] = think_streamer
full_think = ""
with torch.inference_mode():
thread_think = Thread(target=lambda: model.generate(**think_inputs, **gen_kwargs_think))
thread_think.start()
for new_text in think_streamer:
full_think += new_text
display_text = f"<|im_start|>think\n{full_think.strip()}"
chatbot[-1][1] = display_text
yield chatbot, "" # second return is artifact placeholder
thread_think.join()
    # Phase 2: answer
    new_prompt = formatted_think_prompt + full_think.strip() + "\n<|im_start|>answer\n"
    input_ids_answer = tokenizer.encode(new_prompt, return_tensors="pt").to(model.device)
    attention_mask_answer = torch.ne(input_ids_answer, tokenizer.pad_token_id).to(model.device)
    answer_inputs = {
        "input_ids": input_ids_answer,
        "attention_mask": attention_mask_answer
    }

    gen_kwargs_answer = initialize_gen_kwargs()
    gen_kwargs_answer["max_new_tokens"] = ANSWER_MAX_NEW_TOKENS
    answer_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs_answer["streamer"] = answer_streamer
full_answer = ""
with torch.inference_mode():
thread_answer = Thread(target=lambda: model.generate(**answer_inputs, **gen_kwargs_answer))
thread_answer.start()
for new_text in answer_streamer:
full_answer += new_text
display_text = (
f"<|im_start|>think\n{full_think.strip()}\n\n"
f"<|im_start|>answer\n{full_answer.strip()}"
)
chatbot[-1][1] = display_text
yield chatbot, ""
thread_answer.join()
    log_conversation(chatbot)

    # Once the final answer is complete, parse out the HTML code block and
    # return it as an artifact (iframe).
    html_code = extract_html_code_block(full_answer)
    sandbox_iframe = send_to_sandbox(html_code)
    yield chatbot, sandbox_iframe
# ----------------------------------------------------------------------
# 7. Logging and Clearing
# ----------------------------------------------------------------------
def log_conversation(chatbot: List[List[str]]):
    logger.info("[CONVERSATION]")
    for i, (query, response) in enumerate(chatbot, 1):
        logger.info(f"Q{i}: {query}\nA{i}: {response}")

def clear_chat():
    return [], "", ""
# ----------------------------------------------------------------------
# 8. Gradio UI Setup
# ----------------------------------------------------------------------
css_code = """
.left_header {
    display: flex;
    flex-direction: column;
    justify-content: center;
    align-items: center;
}
.right_panel {
    margin-top: 16px;
    border: 1px solid #BFBFC4;
    border-radius: 8px;
    overflow: hidden;
}
.render_header {
    height: 30px;
    width: 100%;
    padding: 5px 16px;
    background-color: #f5f5f5;
}
.header_btn {
    display: inline-block;
    height: 10px;
    width: 10px;
    border-radius: 50%;
    margin-right: 4px;
}
.render_header > .header_btn:nth-child(1) {
    background-color: #f5222d;
}
.render_header > .header_btn:nth-child(2) {
    background-color: #faad14;
}
.render_header > .header_btn:nth-child(3) {
    background-color: #52c41a;
}
.right_content {
    height: 920px;
    display: flex;
    flex-direction: column;
    justify-content: center;
    align-items: center;
}
.html_content {
    width: 100%;
    height: 920px;
}
"""
svg_content = """
<svg width="40" height="40" viewBox="0 0 45 45" fill="none" xmlns="http://www.w3.org/2000/svg">
<circle cx="22.5" cy="22.5" r="22.5" fill="#5572F9"/>
<path d="M22.5 11.25L26.25 16.875H18.75L22.5 11.25Z" fill="white"/>
<path d="M22.5 33.75L26.25 28.125H18.75L22.5 33.75Z" fill="white"/>
<path d="M28.125 22.5L22.5 28.125L16.875 22.5L22.5 16.875L28.125 22.5Z" fill="white"/>
</svg>
"""
with gr.Blocks(title=model_name.split('/')[-1], css=css_code) as demo:
    gr.HTML(f"""
    <div class="left_header" style="margin-bottom: 20px;">
        {svg_content}
        <h1>{model_name.split('/')[-1]} - Chat + Artifacts</h1>
        <p>(Two-phase chain-of-thought with artifact extraction)</p>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                label="Chat",
                height=520,
                show_copy_button=True
            )
            with gr.Row():
                text_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your query...",
                    lines=1
                )
            with gr.Row():
                submit_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
        with gr.Column(scale=6):
            # Mock browser chrome (three dots) above the rendered artifact.
            gr.HTML(
                '<div class="render_header">'
                '<span class="header_btn"></span>'
                '<span class="header_btn"></span>'
                '<span class="header_btn"></span>'
                '</div>'
            )
            artifact_html = gr.HTML(
                value="",
                elem_classes="html_content"
            )

    # The Send button and the Enter key share the same two-step pipeline:
    # first record the user turn, then stream the two-phase generation.
    submit_btn.click(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(
        ovis_chat, [chatbot], [chatbot, artifact_html]
    )
    text_input.submit(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(
        ovis_chat, [chatbot], [chatbot, artifact_html]
    )
    clear_btn.click(
        clear_chat,
        outputs=[chatbot, text_input, artifact_html]
    )
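
# Note: share=True only matters when running this script locally; on a hosted
# Hugging Face Space the app is already served publicly and the flag is ignored.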
demo.queue(default_concurrency_limit=1).launch(server_name="0.0.0.0", share=True)