|
import ast
import base64
import io
import json
import math
import os
import re
import uuid
from datetime import datetime

import gradio as gr
import oss2
from openai import OpenAI
from oss2.credentials import EnvironmentVariableCredentialsProvider
from PIL import ImageDraw
|
|
|
|
|
|
|
DESCRIPTION = "[UI-TARS](https://github.com/bytedance/UI-TARS)"

# OpenAI-compatible chat client; base URL and key come from the environment
# (os.environ.get returns None for unset variables).
client = OpenAI(
    base_url=os.environ.get("ENDPOINT_URL"),
    api_key=os.environ.get("API_KEY"),
)

# Instruction prefix that run_ui() prepends to every user query.
prompt = "Output only the coordinate of one box in your response. "

# OSS bucket used by upload_images(). Credentials are read by
# EnvironmentVariableCredentialsProvider; the bucket name comes from $BUCKET.
auth = oss2.ProviderAuthV4(EnvironmentVariableCredentialsProvider())
endpoint = 'oss-us-east-1.aliyuncs.com'
region = "us-east-1"
bucket = oss2.Bucket(auth, endpoint, os.environ.get("BUCKET"), region=region)
|
|
|
|
|
def draw_point_area(image, point):
    """Return a copy of *image* with the predicted click point marked in red.

    *point* is an ``(x, y)`` pair on a 0-1000 relative scale along each axis.
    It is converted to pixels of *image*. The mark is a small filled red dot
    at the point, surrounded by a red circle whose radius is 1/15 of the
    shorter image side.

    The input image is NOT modified. Drawing on a copy keeps the caller's
    screenshot clean, e.g. so it can be stored next to the annotated one.

    Args:
        image: PIL image to annotate.
        point: ``(x, y)`` in 0-1000 relative units.

    Returns:
        A new PIL image with the annotation drawn on it.
    """
    annotated = image.copy()
    radius = min(annotated.width, annotated.height) // 15
    x = round(point[0] / 1000 * annotated.width)
    y = round(point[1] / 1000 * annotated.height)

    draw = ImageDraw.Draw(annotated)
    draw.ellipse((x - radius, y - radius, x + radius, y + radius), outline='red', width=2)
    draw.ellipse((x - 2, y - 2, x + 2, y + 2), fill='red')
    return annotated
|
|
|
|
|
def resize_image(image):
    """Rescale *image* so that its area hits a fixed pixel budget.

    Images with more than 6000 * 28 * 28 pixels are rescaled to an area of
    about 2700 * 28 * 28 pixels. All other images are rescaled to about
    1340 * 28 * 28 pixels, so small images are upscaled. The aspect ratio is
    preserved, and the new width and height are truncated to integers.
    """
    unit = 28 * 28
    pixel_count = image.width * image.height
    budget = (2700 if pixel_count > 6000 * unit else 1340) * unit
    scale = math.sqrt(budget / pixel_count)
    new_size = (int(image.width * scale), int(image.height * scale))
    return image.resize(new_size)
|
|
|
|
|
def upload_images(session_id, image, result_image, query):
    """Write one submission's images and metadata to the OSS bucket.

    Three objects are written, keyed by *session_id*:
      * ``<session_id>.png``       -- *image* encoded as PNG
      * ``<session_id>-draw.png``  -- *result_image* encoded as PNG
      * ``<session_id>.json``      -- JSON with the query, both object keys
                                      and the session id
    """
    def to_png(img):
        # Encode a PIL image to PNG bytes in memory.
        buffer = io.BytesIO()
        img.save(buffer, format="png")
        return buffer.getvalue()

    source_key = f"{session_id}.png"
    drawn_key = f"{session_id}-draw.png"
    record = {
        "query": query,
        "resize_image": source_key,
        "result_image": drawn_key,
        "session_id": session_id,
    }

    bucket.put_object(source_key, to_png(image))
    bucket.put_object(drawn_key, to_png(result_image))
    bucket.put_object(f"{session_id}.json", json.dumps(record))
    print("end upload images")
|
|
|
|
|
def run_ui(image, query, session_id, is_example_image):
    """Ask the hosted model where to click for *query* on *image*.

    Steps:
      1. Record the original size, then resize with resize_image().
      2. Send the resized screenshot (as a base64 PNG data URL) together with
         the module-level ``prompt`` + *query* to the chat-completions endpoint.
      3. Take the first ``(x, y)`` pair found in the reply. Its values are on a
         0-1000 relative scale. Draw it on a copy of the resized screenshot,
         then map it back to pixels of the ORIGINAL image.
      4. If *is_example_image* is the string ``"False"`` and a point was
         found, upload the clean and annotated images via upload_images().

    Args:
        image: PIL screenshot as provided by the UI.
        query: user instruction text.
        session_id: key prefix used for uploaded objects.
        is_example_image: ``"True"``/``"False"`` flag from the hidden dropdown.

    Returns:
        ``(images, coords)``:
          * ``images`` is a list holding the annotated screenshot, or is
            empty when no coordinates could be parsed.
          * ``coords`` is ``str`` of the ``(x, y)`` pixel tuple in
            original-image pixels, or ``"None"``.
    """
    click_xy = None
    result_image = None
    images_during_iterations = []
    width, height = image.width, image.height

    image = resize_image(image)
    png_buffer = io.BytesIO()  # named to avoid shadowing the builtin ``bytes``
    image.save(png_buffer, format="png")
    base64_image = base64.standard_b64encode(png_buffer.getvalue()).decode("utf-8")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}},
                {"type": "text", "text": prompt + query},
            ],
        }
    ]
    response = client.chat.completions.create(
        model="tgi",
        messages=messages,
        temperature=1.0,
        top_p=0.7,
        max_tokens=128,
        frequency_penalty=1,
        stream=False,
    )
    # The SDK types ``message.content`` as optional; treat None as "no match".
    output_text = response.choices[0].message.content or ""

    # Accept optional whitespace, e.g. "(123,456)" as well as "(123, 456)".
    # Parsing with int() also accepts leading zeros, which ast.literal_eval rejects.
    match = re.search(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)", output_text)
    if match:
        rel_x, rel_y = int(match.group(1)), int(match.group(2))
        # Annotate a copy so ``image`` stays clean for the upload below.
        result_image = draw_point_area(image.copy(), (rel_x, rel_y))
        images_during_iterations.append(result_image)
        click_xy = round(rel_x / 1000 * width), round(rel_y / 1000 * height)

    # Without parsed coordinates there is no annotated image to store; the
    # previous code raised UnboundLocalError on ``result_image`` here.
    if is_example_image == "False" and result_image is not None:
        upload_images(session_id, image, result_image, query)

    return images_during_iterations, str(click_xy)
|
|
|
|
|
def update_vote(vote_type, image, click_image, prompt, is_example):
    """Handle a thumbs-up / thumbs-down click and return a status message.

    Args:
        vote_type: ``"upvote"`` or ``"downvote"``.
        image: input screenshot (currently unused).
        click_image: gallery contents with the annotated output (currently
            unused; may be None/empty, e.g. after the Clear button).
        prompt: instruction text (currently unused).
        is_example: ``"True"``/``"False"`` flag from the hidden dropdown; a
            bool is accepted too.

    Returns:
        A short status string.
    """
    if vote_type == "upvote":
        return "Everything good"

    if str(is_example) == "True":
        return "Do nothing for example"

    # TODO: persist bad cases (image, click_image, prompt) somewhere. The
    # previous placeholder statements ``click_image[0]`` and ``image.size``
    # had no effect except crashing when the gallery or image was empty.
    return "Thank you for your feedback!"
|
|
|
|
|
# Bundled example screenshots as [image path, instruction, is_example] rows.
# build_demo() feeds the first two columns to gr.Examples. The third column
# is what set_is_example() reports when the instruction textbox matches.
examples = [
    [image_path, instruction, True]
    for image_path, instruction in (
        ("./examples/solitaire.png", "Play the solitaire collection"),
        ("./examples/weather_ui.png", "Open map"),
        ("./examples/football_live.png", "click team 1 win"),
        ("./examples/windows_panel.png", "switch to documents"),
        ("./examples/paint_3d.png", "rotate left"),
        ("./examples/finder.png", "view files from airdrop"),
        ("./examples/amazon.jpg", "Search bar at the top of the page"),
        ("./examples/semantic.jpg", "Home"),
        ("./examples/accweather.jpg", "Select May"),
        ("./examples/arxiv.jpg", "Home"),
        ("./examples/health.jpg", "text labeled by 2023/11/26"),
        ("./examples/ios_setting.png", "Turn off Do not disturb."),
    )
]
|
|
|
|
|
|
|
# Page header, rendered with gr.Markdown at the top of build_demo(): the
# project title followed by links to the model card, code, paper, Midscene
# and the Discord server.
title_markdown = ("""

# UI-TARS Pioneering Automated GUI Interaction with Native Agents

[[🤗Model](https://huggingface.co/bytedance-research/UI-TARS-7B-SFT)] [[⌨️Code](https://github.com/bytedance/UI-TARS)] [[📑Paper](https://github.com/bytedance/UI-TARS/blob/main/UI_TARS_paper.pdf)] [🏄[Midscene (Browser Automation)](https://github.com/web-infra-dev/Midscene)] [🫨[Discord](https://discord.gg/txAE43ps)]

""")


# Terms-of-use notice (English followed by a Chinese translation), rendered
# with gr.Markdown below the main interface in build_demo().
tos_markdown = ("""

### Terms of use

This demo is governed by the original license of UI-TARS. We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, including hate speech, violence, pornography, deception, etc. (注:本演示受UI-TARS的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)

""")


# License section, rendered after the terms of use.
learn_more_markdown = ("""

### License

Apache License 2.0

""")


# Credit for the ShowUI Space this app's code was adapted from; rendered last.
code_adapt_markdown = ("""

### Acknowledgments

The app code is modified from [ShowUI](https://huggingface.co/spaces/showlab/ShowUI)

""")
|
|
|
|
|
# Extra CSS passed to gr.Blocks(css=...) in build_demo().
# NOTE(review): no component in build_demo() sets elem_id="buttons" or
# elem_id="chatbot"; the feedback-button row uses elem_id="action-buttons".
# Both rules below may therefore match nothing. Confirm before relying on them.
block_css = """

#buttons button {

min-width: min(120px,100%);

}



#chatbot img {

max-width: 80%;

max-height: 80vh;

width: auto;

height: auto;

object-fit: contain;

}

"""
|
|
|
|
|
def build_demo():
    """Build and return the Gradio Blocks UI for the UI-TARS grounding demo.

    Layout is one row of three columns:
      * input: screenshot, instruction textbox and Submit button;
      * output: annotated gallery, final coordinates, image size, and the
        feedback / Clear buttons;
      * examples: clickable examples plus a hidden is-example flag.
    The terms-of-use, license and acknowledgment sections follow the row.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` instance.
    """
    with gr.Blocks(title="UI-TARS Demo", theme=gr.themes.Default(), css=block_css) as demo:
        state_session_id = gr.State(value=None)
        gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                imagebox = gr.Image(type="pil", label="Input Screenshot")
                textbox = gr.Textbox(
                    show_label=True,
                    placeholder="Enter an instruction and press Submit",
                    label="Instruction",
                )
                submit_btn = gr.Button(value="Submit", variant="primary")

            with gr.Column(scale=6):
                output_gallery = gr.Gallery(label="Output with click", object_fit="contain", preview=True)
                gr.HTML(
                    """
                    <p><strong>Notice:</strong> The <span style="color: red;">red point</span> with a circle on the output image represents the predicted coordinates for a click.</p>
                    """
                )
                with gr.Row():
                    output_coords = gr.Textbox(label="Final Coordinates")
                    image_size = gr.Textbox(label="Image Size")

                gr.HTML(
                    """
                    <p><strong>Expected result or not? help us improve! ⬇️</strong></p>
                    """
                )
                with gr.Row(elem_id="action-buttons", equal_height=True):
                    upvote_btn = gr.Button(value="👍 Looks good!", variant="secondary")
                    downvote_btn = gr.Button(value="👎 Wrong coordinates!", variant="secondary")
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=True)

            with gr.Column(scale=3):
                gr.Examples(
                    examples=[[e[0], e[1]] for e in examples],
                    inputs=[imagebox, textbox],
                    outputs=[textbox],
                    examples_per_page=3,
                )

                # Hidden "True"/"False" flag passed to run_ui. Submissions whose
                # flag is not "False" are not uploaded.
                is_example_dropdown = gr.Dropdown(
                    choices=["True", "False"],
                    value="False",
                    visible=False,
                    label="Is Example Image",
                )

        def set_is_example(query):
            """Return the example flag (as a string) if *query* matches an example."""
            # The textbox may be cleared to None/"" (e.g. by the Clear button),
            # so normalise before stripping.
            query = (query or "").strip()
            for _, example_query, is_example in examples:
                if query == example_query.strip():
                    return str(is_example)
            return "False"

        textbox.change(
            set_is_example,
            inputs=[textbox],
            outputs=[is_example_dropdown],
        )

        def on_submit(image, query, is_example_image):
            """Run one prediction and return values for the output components."""
            if image is None:
                # gr.Error shows its message in the UI, unlike a plain ValueError.
                raise gr.Error("No image provided. Please upload an image before submitting.")

            # A seconds-resolution timestamp alone collides when two users submit
            # within the same second. Their OSS objects would then overwrite each
            # other, so a random suffix keeps the keys unique.
            session_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
            images_during_iterations, click_coords = run_ui(image, query, session_id, is_example_image)
            return images_during_iterations, click_coords, session_id, f"{image.width}x{image.height}"

        submit_btn.click(
            on_submit,
            [imagebox, textbox, is_example_dropdown],
            [output_gallery, output_coords, state_session_id, image_size],
        )

        clear_btn.click(
            lambda: (None, None, None, None, None, None),
            inputs=None,
            outputs=[imagebox, textbox, output_gallery, output_coords, state_session_id, image_size],
            queue=False,
        )

        upvote_btn.click(
            lambda image, click_image, query, is_example: update_vote("upvote", image, click_image, query, is_example),
            inputs=[imagebox, output_gallery, textbox, is_example_dropdown],
            outputs=[],
            queue=False,
        )

        downvote_btn.click(
            lambda image, click_image, query, is_example: update_vote("downvote", image, click_image, query, is_example),
            inputs=[imagebox, output_gallery, textbox, is_example_dropdown],
            outputs=[],
            queue=False,
        )

        gr.Markdown(tos_markdown)
        gr.Markdown(learn_more_markdown)
        gr.Markdown(code_adapt_markdown)

    return demo
|
|
|
if __name__ == "__main__":
    # Build the UI, enable Gradio's request queue with api_open=False, and
    # serve on all interfaces at port 7860 with debug output on.
    demo = build_demo()
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "debug": True}
    demo.queue(api_open=False).launch(**launch_options)