import subprocess
import os

# Install flash-attn without compiling CUDA kernels; merge with the parent
# environment so pip and CUDA paths remain visible to the subprocess.
subprocess.run(
    'pip install flash-attn==2.7.0.post2 --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True
)
subprocess.run(
    'pip install transformers',
    shell=True
)

import spaces
import re
import logging
from typing import List
from threading import Thread
import base64

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# ----------------------------------------------------------------------
# 1. Setup Model & Tokenizer
# ----------------------------------------------------------------------
model_name = 'smirki/UIGEN-T1.1-Qwen-14B'  # Change as needed
use_thread = True  # Generation happens in a background thread

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

logging.getLogger("httpx").setLevel(logging.WARNING)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ----------------------------------------------------------------------
# 2. Two-Phase Prompt Templates
# ----------------------------------------------------------------------
s1_inference_prompt_think_only = """<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
<|im_start|>think
"""

# ----------------------------------------------------------------------
# 3. Generation Parameter Setup
# ----------------------------------------------------------------------
THINK_MAX_NEW_TOKENS = 12000
ANSWER_MAX_NEW_TOKENS = 12000

def initialize_gen_kwargs():
    return {
        "max_new_tokens": 1024,  # default; will be overwritten per phase
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.05,
        # "eos_token_id": model.generation_config.eos_token_id,  # Removed to avoid premature stopping
        "pad_token_id": tokenizer.pad_token_id,
        "use_cache": True,
        "streamer": None  # dynamically added
    }

# ----------------------------------------------------------------------
# 4. Helper to submit chat
# ----------------------------------------------------------------------
def submit_chat(chatbot, text_input):
    if not text_input.strip():
        return chatbot, ""
    response = ""
    # Append a mutable [user, assistant] pair so the streaming loop in
    # ovis_chat can update the assistant slot in place.
    chatbot.append([text_input, response])
    return chatbot, ""

# ----------------------------------------------------------------------
# 5. Artifacts Handling
#    We parse code from the final answer and display it in an iframe
# ----------------------------------------------------------------------
def extract_html_code_block(text: str) -> str:
    """
    Look for a ```html ... ``` block in the text.
    If found, return only that block content.
    Otherwise, return the entire text.
    """
    pattern = r'```html\s*(.*?)\s*```'
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1).strip()
    else:
        return text.strip()

def send_to_sandbox(html_code: str) -> str:
    """
    Convert the code to a data URI iframe so it can be rendered
    inside a Gradio HTML component.
    """
    encoded_html = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
    data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
    # Sizing and sandbox attributes are sensible defaults matching the
    # .html_content panel styled in css_code below.
    return (
        f'<iframe src="{data_uri}" width="100%" height="920px" '
        f'sandbox="allow-scripts allow-same-origin"></iframe>'
    )
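# Illustrative example of the two helpers above (comments only, not executed).
# Given a hypothetical model reply such as:
#   "Here is the page:\n```html\n<h1>Hello</h1>\n```"
# extract_html_code_block() returns "<h1>Hello</h1>", and send_to_sandbox()
# wraps it as '<iframe src="data:text/html;charset=utf-8;base64,..."></iframe>',
# which the artifact gr.HTML panel can render directly.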
# ----------------------------------------------------------------------
# 6. The Two-Phase Streaming Inference
#    - Phase 1: "think" (chain-of-thought)
#    - Phase 2: "answer"
# ----------------------------------------------------------------------
@spaces.GPU
def ovis_chat(chatbot: List[List[str]]):
    # Phase 1: chain-of-thought
    last_query = chatbot[-1][0]
    formatted_think_prompt = s1_inference_prompt_think_only.format(question=last_query)
    input_ids_think = tokenizer.encode(formatted_think_prompt, return_tensors="pt").to(model.device)
    attention_mask_think = torch.ne(input_ids_think, tokenizer.pad_token_id).to(model.device)
    think_inputs = {
        "input_ids": input_ids_think,
        "attention_mask": attention_mask_think
    }
    gen_kwargs_think = initialize_gen_kwargs()
    gen_kwargs_think["max_new_tokens"] = THINK_MAX_NEW_TOKENS
    think_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs_think["streamer"] = think_streamer

    full_think = ""
    with torch.inference_mode():
        thread_think = Thread(target=lambda: model.generate(**think_inputs, **gen_kwargs_think))
        thread_think.start()
        for new_text in think_streamer:
            full_think += new_text
            display_text = f"<|im_start|>think\n{full_think.strip()}"
            chatbot[-1][1] = display_text
            yield chatbot, ""  # second return is artifact placeholder
    thread_think.join()

    # Phase 2: answer
    new_prompt = formatted_think_prompt + full_think.strip() + "\n<|im_start|>answer\n"
    input_ids_answer = tokenizer.encode(new_prompt, return_tensors="pt").to(model.device)
    attention_mask_answer = torch.ne(input_ids_answer, tokenizer.pad_token_id).to(model.device)
    answer_inputs = {
        "input_ids": input_ids_answer,
        "attention_mask": attention_mask_answer
    }
    gen_kwargs_answer = initialize_gen_kwargs()
    gen_kwargs_answer["max_new_tokens"] = ANSWER_MAX_NEW_TOKENS
    answer_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs_answer["streamer"] = answer_streamer

    full_answer = ""
    with torch.inference_mode():
        thread_answer = Thread(target=lambda: model.generate(**answer_inputs, **gen_kwargs_answer))
        thread_answer.start()
        for new_text in answer_streamer:
            full_answer += new_text
            display_text = (
                f"<|im_start|>think\n{full_think.strip()}\n\n"
                f"<|im_start|>answer\n{full_answer.strip()}"
            )
            chatbot[-1][1] = display_text
            yield chatbot, ""
    thread_answer.join()

    log_conversation(chatbot)

    # Once final answer is complete, parse out HTML code block and
    # return it as an artifact (iframe).
    html_code = extract_html_code_block(full_answer)
    sandbox_iframe = send_to_sandbox(html_code)
    yield chatbot, sandbox_iframe

# ----------------------------------------------------------------------
# 7. Logging and Clearing
# ----------------------------------------------------------------------
def log_conversation(chatbot: List[List[str]]):
    logger.info("[CONVERSATION]")
    for i, (query, response) in enumerate(chatbot, 1):
        logger.info(f"Q{i}: {query}\nA{i}: {response}")

def clear_chat():
    return [], "", ""
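# For reference, the text assembled across the two phases of ovis_chat above
# looks roughly like this (template from s1_inference_prompt_think_only plus
# the "<|im_start|>answer" suffix added before phase 2):
#
#   <|im_start|>user
#   {question}<|im_end|>
#   <|im_start|>assistant
#   <|im_start|>think
#   ...chain-of-thought streamed in phase 1...
#   <|im_start|>answer
#   ...final answer streamed in phase 2...
#
# Only the phase-2 answer text is scanned for a ```html block when building the artifact.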
# ----------------------------------------------------------------------
# 8. Gradio UI Setup
# ----------------------------------------------------------------------
css_code = """
.left_header {
  display: flex;
  flex-direction: column;
  justify-content: center;
  align-items: center;
}
.right_panel {
  margin-top: 16px;
  border: 1px solid #BFBFC4;
  border-radius: 8px;
  overflow: hidden;
}
.render_header {
  height: 30px;
  width: 100%;
  padding: 5px 16px;
  background-color: #f5f5f5;
}
.header_btn {
  display: inline-block;
  height: 10px;
  width: 10px;
  border-radius: 50%;
  margin-right: 4px;
}
.render_header > .header_btn:nth-child(1) {
  background-color: #f5222d;
}
.render_header > .header_btn:nth-child(2) {
  background-color: #faad14;
}
.render_header > .header_btn:nth-child(3) {
  background-color: #52c41a;
}
.right_content {
  height: 920px;
  display: flex;
  flex-direction: column;
  justify-content: center;
  align-items: center;
}
.html_content {
  width: 100%;
  height: 920px;
}
"""

# Placeholder for an inline SVG logo; it is interpolated into the header below.
svg_content = """ """

with gr.Blocks(title=model_name.split('/')[-1], css=css_code) as demo:
    gr.HTML(f"""
    <div class="left_header">
        {svg_content}
        <h1>{model_name.split('/')[-1]} - Chat + Artifacts</h1>
        <p>(Two-phase chain-of-thought with artifact extraction)</p>
    </div>
""") with gr.Row(): with gr.Column(scale=4): chatbot = gr.Chatbot( label="Chat", height=520, show_copy_button=True ) with gr.Row(): text_input = gr.Textbox( label="Prompt", placeholder="Enter your query...", lines=1 ) with gr.Row(): submit_btn = gr.Button("Send", variant="primary") clear_btn = gr.Button("Clear", variant="secondary") with gr.Column(scale=6): gr.HTML('
') artifact_html = gr.HTML( value="", elem_classes="html_content" ) submit_btn.click( submit_chat, [chatbot, text_input], [chatbot, text_input] ).then( ovis_chat, [chatbot], [chatbot, artifact_html] ) text_input.submit( submit_chat, [chatbot, text_input], [chatbot, text_input] ).then( ovis_chat, [chatbot], [chatbot, artifact_html] ) clear_btn.click( clear_chat, outputs=[chatbot, text_input, artifact_html] ) demo.queue(default_concurrency_limit=1).launch(server_name="0.0.0.0", share=True)
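# Quick local run (assumptions: the file is saved as app.py, a CUDA GPU with
# enough memory for the 14B bf16 checkpoint, and torch/transformers/gradio/spaces
# are already installed):
#
#   python app.py
#
# Gradio listens on 0.0.0.0:7860 by default; share=True also prints a temporary public link.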