File size: 9,065 Bytes
1e0442b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5b1f88
1e0442b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e64071c
1e0442b
 
 
 
 
 
 
d944583
1e0442b
 
 
 
 
 
 
 
 
 
 
d944583
 
 
 
1e0442b
3deb6e9
d944583
c9f689e
1e0442b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58bd064
1e0442b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58bd064
1e0442b
3f95289
 
 
 
 
 
 
 
c9f689e
1e0442b
 
c9f689e
1e0442b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e64071c
1e0442b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import argparse
import os
import re
import threading
import time
from datetime import datetime, timedelta

import torch
from threading import Thread, Event
from PIL import Image, ImageDraw
import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
from typing import List
import spaces

# Shared cancellation flag: stop_now() sets it, predict() clears it at the
# start of every run and polls it while streaming tokens.
stop_event = Event()

def _remove_stale_files(directory, cutoff):
    """Delete regular files directly inside *directory* last modified before *cutoff*."""
    try:
        filenames = os.listdir(directory)
    except OSError:
        # The directory may not exist yet (nothing has been written) or may be
        # unreadable; skip it for this pass instead of killing the cleanup loop.
        return

    for filename in filenames:
        file_path = os.path.join(directory, filename)
        try:
            if not os.path.isfile(file_path):
                continue
            file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
            if file_mtime < cutoff:
                os.remove(file_path)
        except OSError:
            # The file disappeared or is locked between listing and removal;
            # cleanup is best-effort, so move on to the next entry.
            continue


def delete_old_files():
    """Periodically purge stale files from the demo's scratch directories.

    Runs forever: every 600 seconds it removes regular files in ``./outputs``
    and ``./gradio_tmp`` whose modification time is more than 10 minutes old.
    Sub-directories are not descended into. Missing directories and files
    that cannot be inspected or removed are skipped so that one failure does
    not terminate the loop.
    """
    directories = ["./outputs", "./gradio_tmp"]
    while True:
        cutoff = datetime.now() - timedelta(minutes=10)
        for directory in directories:
            _remove_stale_files(directory, cutoff)
        time.sleep(600)


# Start the cleanup loop in a background daemon thread as soon as this module
# is loaded; being a daemon, it does not keep the process alive on exit.
threading.Thread(target=delete_old_files, daemon=True).start()


def draw_boxes_on_image(image: Image.Image, boxes: List[List[float]], save_path: str):
    """Draw red rectangles on *image* and save the result to *save_path*.

    Args:
        image: Image to annotate. It is drawn on in place.
        boxes: Boxes as ``[x0, y0, x1, y1]`` where each value is multiplied by
            the image width (x) or height (y) to obtain pixel coordinates.
        save_path: Destination file path passed to ``image.save``.
    """
    draw = ImageDraw.Draw(image)
    for box in boxes:
        # The corners come from model output and are not guaranteed to be
        # ordered. Recent Pillow releases raise ValueError when the second
        # corner precedes the first, so order each axis before drawing.
        x_min, x_max = sorted((int(box[0] * image.width), int(box[2] * image.width)))
        y_min, y_max = sorted((int(box[1] * image.height), int(box[3] * image.height)))
        draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
    image.save(save_path)


def preprocess_messages(history, img_path, platform_str, format_str):
    """Build the text query and load the screenshot for one model call.

    Every model reply in *history* is searched for a
    ``Grounded Operation: ...`` line; the captured operations are listed,
    numbered from 0, under a "History steps" heading. The task is the user
    text of the last history entry, or a placeholder when history is empty.

    Returns:
        A ``(query, image)`` tuple: the prompt string and the image at
        *img_path* opened and converted to RGB.
    """
    grounded_re = r"Grounded Operation:\s*(.*)"
    search_hits = [re.search(grounded_re, reply) for _, reply in history]
    past_operations = [hit.group(1) for hit in search_hits if hit]

    numbered_steps = "".join(
        f"\n{idx}. {operation}" for idx, operation in enumerate(past_operations)
    )
    history_str = "\nHistory steps: " + numbered_steps

    task = history[-1][0] if history else "No task provided"

    query = f"Task: {task}{history_str}\n{platform_str}{format_str}"
    screenshot = Image.open(img_path).convert("RGB")
    return query, screenshot


@spaces.GPU()
def predict(history, max_length, img_path, platform_str, format_str, output_dir):
    """Stream the model's answer for the latest round and annotate any boxes.

    This is a generator used as a Gradio event handler. It expects the current
    round to already be the last entry of *history* (``[task, ""]``, as
    appended by ``user``) and appends streamed text to that entry's reply.

    Args:
        history: Chat history as a list of ``[user_text, model_text]`` pairs.
        max_length: Value forwarded to ``model.generate`` as ``max_length``.
        img_path: Path of the uploaded screenshot.
        platform_str: Platform hint appended to the prompt.
        format_str: Answer-format hint appended to the prompt.
        output_dir: Directory where the annotated image is written.

    Yields:
        ``(history, None)`` after each streamed chunk, then a final
        ``(history, output_path)`` if the reply contains ``box=[[...]]``
        coordinates, otherwise ``(history, None)``. If the stop flag is set
        mid-stream, the current round is removed from *history*, a single
        ``(history, None)`` is yielded and the generator ends.
    """
    # Reset the stop_event at the start of prediction
    stop_event.clear()

    # The current round has already been appended by `user`, so the length
    # *before* this round is one less than what we see now. Using len(history)
    # here would make the rollback below a no-op.
    prev_len = max(len(history) - 1, 0)

    query, image = preprocess_messages(history, img_path, platform_str, format_str)
    inputs = tokenizer.apply_chat_template(
        [{"role": "user", "image": image, "content": query}],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "position_ids": inputs["position_ids"],
        "images": inputs["images"],
        "streamer": streamer,
        "max_length": max_length,
        "do_sample": True,
        "top_k": 1,
    }
    # Generation runs in a worker thread; this generator consumes the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    for new_token in streamer:
        # Check if stop event is set
        if stop_event.is_set():
            # Stop streaming to the UI and roll back the current round.
            # NOTE(review): nothing here signals model.generate to stop, so
            # the worker thread presumably keeps running until it finishes.
            while len(history) > prev_len:
                history.pop()
            yield history, None
            return

        if new_token:
            history[-1][1] += new_token
        yield history, None

    # If finished without stop event
    response = history[-1][1]
    box_pattern = r"box=\[\[?(\d+),(\d+),(\d+),(\d+)\]?\]"
    matches = re.findall(box_pattern, response)
    if matches:
        # Coordinates are divided by 1000 to get the fractional values that
        # draw_boxes_on_image scales by the image size.
        boxes = [[int(x) / 1000 for x in match] for match in matches]
        os.makedirs(output_dir, exist_ok=True)
        base_name = os.path.splitext(os.path.basename(img_path))[0]
        # Number the output by how many rounds have both a task and a reply.
        round_num = sum(1 for (u, m) in history if u and m)
        output_path = os.path.join(output_dir, f"{base_name}_{round_num}.png")
        image = Image.open(img_path).convert("RGB")
        draw_boxes_on_image(image, boxes, output_path)
        yield history, output_path
    else:
        yield history, None


def user(task, history):
    """Open a new chat round for *task* and clear the input textbox.

    Returns:
        ``("", new_history)`` where ``new_history`` is *history* plus a
        ``[task, ""]`` entry; the input list itself is not modified.
    """
    new_round = [task, ""]
    cleared_textbox = ""
    return cleared_textbox, history + [new_round]


def undo_last_round(history, output_img):
    """Drop the most recent chat round and clear the annotated image.

    Returns:
        ``(history, None)``; *history* is modified in place when non-empty
        and returned unchanged when empty or ``None``.
    """
    if not history:
        return history, None
    history.pop()
    return history, None


def clear_all_history():
    """Return ``(None, None)`` to reset both the chatbot and the output image."""
    cleared_chat = None
    cleared_image = None
    return cleared_chat, cleared_image


def stop_now():
    """Raise the stop flag polled by ``predict`` and leave both outputs untouched.

    Returns:
        Two ``gr.update()`` values with no arguments, one per output component.
    """
    stop_event.set()
    chat_unchanged = gr.update()
    image_unchanged = gr.update()
    return chat_unchanged, image_unchanged


def main():
    """Parse CLI options, load the model, build the Gradio UI and launch it.

    The tokenizer and model are stored in the module-level globals
    ``tokenizer`` and ``model``, which ``predict`` reads.

    Raises:
        ValueError: If ``--format_key`` is not one of the known format keys.
    """
    parser = argparse.ArgumentParser(description="CogAgent Gradio Demo")
    parser.add_argument("--model_dir", default="THUDM/cogagent-9b-20241220", help="Path or identifier of the model.")
    parser.add_argument("--format_key", default="action_op_sensitive", help="Key to select the prompt format.")
    parser.add_argument("--platform", default="Mac", help="Platform information string.")
    parser.add_argument("--output_dir", default="outputs", help="Directory to save annotated images.")
    args = parser.parse_args()

    # Maps each accepted --format_key to the instruction appended to the prompt.
    format_dict = {
        "action_op_sensitive": "(Answer in Action-Operation-Sensitive format.)",
        "status_plan_action_op": "(Answer in Status-Plan-Action-Operation format.)",
        "status_action_op_sensitive": "(Answer in Status-Action-Operation-Sensitive format.)",
        "status_action_op": "(Answer in Status-Action-Operation format.)",
        "action_op": "(Answer in Action-Operation format.)"
    }

    if args.format_key not in format_dict:
        raise ValueError(f"Invalid format_key. Available keys: {list(format_dict.keys())}")

    # Published as globals because predict() refers to them by name.
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto"
    ).eval()

    platform_str = f"(Platform: {args.platform})\n"
    format_str = format_dict[args.format_key]

    with gr.Blocks(analytics_enabled=False) as demo:
        gr.HTML("<h1 align='center'>CogAgent-9B-20241220 Demo</h1>")
        gr.HTML(
            """
            <p align='center' style='color:red;'>This demo is for learning and communication purposes only. Users must assume responsibility for the risks associated with AI-generated planning and operations.</p>
            <p align='center' style='color:red;'>In this demo, the model assumes that the user is using a Mac operating system. Therefore, it is recommended to upload screenshots taken on a Mac.</p>
            <p align='left' style='color:black;'>1. Upload a screenshot from your computer (must be from a Mac, and a full-screen screenshot).<br>
            2. Provide your instructions to CogAgent (e.g., send a message to XXX).<br>
            3. Wait for CogAgent to return specific operations. If bounding boxes (Bbox) are detected, they will be displayed in the image area on the right.</p>
            <p align='left' style='color:black;'>The model will only return the next step's instructions. The online demo cannot control your computer. Please visit the <a href="https://github.com/THUDM/CogAgent">GitHub repository</a> for the full version of the demo.</p>
            """
        )
        # Top row: uploaded screenshot next to the read-only annotated result.
        with gr.Row():
            img_path = gr.Image(label="Upload a Screenshot", type="filepath", height=400)
            output_img = gr.Image(type="filepath", label="Annotated Image(If Bbox Return)", height=400, interactive=False)

        # Bottom row: chat and task input on the left, controls on the right.
        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(height=300)
                task = gr.Textbox(show_label=True, placeholder="Input...", label="Task")
                submitBtn = gr.Button("Submit")
            with gr.Column(scale=1):
                max_length = gr.Slider(0, 8192, value=1024, step=1.0, label="Maximum length", interactive=True)
                undo_last_round_btn = gr.Button("Back to Last Round")
                clear_history_btn = gr.Button("Clear All History")

                # Red "Stop Now" button: intended to interrupt generation and
                # roll back the current round's history when clicked.
                stop_now_btn = gr.Button("Stop Now", variant="stop")

        # Submit first appends the task to the chat (unqueued), then streams
        # the model reply through predict (queued). The prompt strings and the
        # output directory are passed to predict as constant gr.State inputs.
        submitBtn.click(
            user, [task, chatbot], [task, chatbot], queue=False
        ).then(
            predict,
            [chatbot, max_length, img_path, gr.State(platform_str), gr.State(format_str),
             gr.State(args.output_dir)],
            [chatbot, output_img],
            queue=True
        )

        undo_last_round_btn.click(undo_last_round, [chatbot, output_img], [chatbot, output_img], queue=False)
        clear_history_btn.click(clear_all_history, None, [chatbot, output_img], queue=False)
        # stop_now only sets the shared flag; its outputs are left unchanged.
        stop_now_btn.click(stop_now, None, [chatbot, output_img], queue=False)

    demo.queue()
    demo.launch()


# Launch the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()