# Overthinking Llama - a Hugging Face Space running on ZeroGPU hardware.
from transformers import MllamaForConditionalGeneration, AutoProcessor, TextIteratorStreamer
from PIL import Image
import requests
import torch
from threading import Thread
import gradio as gr
from gradio import FileData
import time
import spaces

ckpt = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(
    ckpt, torch_dtype=torch.bfloat16
).to("cuda")
processor = AutoProcessor.from_pretrained(ckpt)
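# Note: the meta-llama checkpoints are gated on the Hugging Face Hub, so running
# this script outside the Space requires accepting the model license and being
# logged in (e.g. via `huggingface-cli login` or an HF_TOKEN environment variable).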
SYSTEM_PROMPT = """ You are a Vision Language Model specialized in visual document analysis. Your task is to analyze visual data and accurately answer user queries using a Chain-of-Thought (COT) approach. Self-reflection and error correction are crucial. | |
**Reasoning Process:** | |
1. **Initial Reasoning:** | |
* Use `<Thinking>` to describe your initial understanding, identify relevant sections, and generate a preliminary answer. | |
2. **Reflection and Error Check:** | |
* Use `<Reflection>` to critically examine your initial reasoning: section relevance, data accuracy, and alternative interpretations. Identify any potential errors. | |
3. **Refinement and Correction:** | |
* Use `<Correction>` to detail any corrections to your approach and why. Refine your answer. If no corrections needed, state "No correction needed". | |
4. **Final Answer:** | |
* Present your final answer in this format: | |
**Reasoning Steps:** | |
1. **Identification:** Briefly identify relevant document sections. | |
2. **Extraction:** State extracted visual/textual features. | |
3. **Synthesis:** Explain how extracted data led to the answer. | |
**Answer:** [Your detailed, accurate answer here] | |
**Requirements:** | |
* Use the COT structure and tags (`<Thinking>`, `<Reflection>`, `<Correction>`). | |
* Provide accurate, succinct answers. | |
* Always perform self-reflection and error correction. | |
* No corrections need to be clearly indicated""" | |

@spaces.GPU  # required on ZeroGPU so each call gets a GPU allocation
def bot_streaming(message, history, max_new_tokens=4048):
    txt = message["text"]
    messages = [{"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}]
    images = []

    # Rebuild the conversation from the Gradio tuple-format history: an image
    # turn shows up as a (filepath,) entry followed by the text/response entry.
    for i, msg in enumerate(history):
        if isinstance(msg[0], tuple):
            messages.append({"role": "user", "content": [{"type": "text", "text": history[i+1][0]}, {"type": "image"}]})
            messages.append({"role": "assistant", "content": [{"type": "text", "text": history[i+1][1]}]})
            images.append(Image.open(msg[0][0]).convert("RGB"))
        elif isinstance(history[i-1], tuple) and isinstance(msg[0], str):
            # This text turn was already folded into the preceding image turn.
            pass
        elif isinstance(history[i-1][0], str) and isinstance(msg[0], str):
            # Text-only turn.
            messages.append({"role": "user", "content": [{"type": "text", "text": msg[0]}]})
            messages.append({"role": "assistant", "content": [{"type": "text", "text": msg[1]}]})

    # Current message: attach the uploaded image, if any.
    if len(message["files"]) == 1:
        if isinstance(message["files"][0], str):  # example images are plain paths
            image = Image.open(message["files"][0]).convert("RGB")
        else:  # uploads arrive as dicts with a "path" key
            image = Image.open(message["files"][0]["path"]).convert("RGB")
        images.append(image)
        messages.append({"role": "user", "content": [{"type": "text", "text": txt}, {"type": "image"}]})
    else:
        messages.append({"role": "user", "content": [{"type": "text", "text": txt}]})

    texts = processor.apply_chat_template(messages, add_generation_prompt=True)

    if images == []:
        inputs = processor(text=texts, return_tensors="pt").to("cuda")
    else:
        inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")

    # Stream tokens from a background generation thread.
    streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer

demo = gr.ChatInterface(
    fn=bot_streaming,
    title="Overthinking Llama",
    examples=[
        [{"text": "Which era does this piece belong to? Give details about the era.", "files": ["./examples/rococo.jpg"]}, 200],
        [{"text": "Where do the droughts happen according to this diagram?", "files": ["./examples/weather_events.png"]}, 250],
        [{"text": "What happens when you take out white cat from this chain?", "files": ["./examples/ai2d_test.jpg"]}, 250],
        [{"text": "How long does it take from invoice date to due date? Be short and concise.", "files": ["./examples/invoice.png"]}, 250],
        [{"text": "Where to find this monument? Can you give me other recommendations around the area?", "files": ["./examples/wat_arun.jpg"]}, 250],
    ],
    textbox=gr.MultimodalTextbox(),
    additional_inputs=[
        gr.Slider(
            minimum=10,
            maximum=4096,  # raised so the 4048-token default is a valid value
            value=4048,
            step=10,
            label="Maximum number of new tokens to generate",
        )
    ],
    cache_examples=False,
    description="Upload an invoice or timesheet, ask a question, and let the model overthink the answer.",
    stop_btn="Stop Generation",
    fill_height=True,
    multimodal=True,
)

demo.launch(debug=True)
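
# ---------------------------------------------------------------------------
# Rough dependency sketch for running this script outside the Space (package
# names are inferred from the imports above, not taken from the original app):
#
#   pip install torch transformers gradio spaces Pillow requests
#
# Locally you also need a CUDA GPU with enough memory to hold the 11B
# checkpoint in bfloat16 (roughly 22 GB of weights before activations).
# ---------------------------------------------------------------------------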