zR commited on
Commit
d944583
·
1 Parent(s): 3deb6e9
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -82,7 +82,7 @@ def predict(history, max_length, img_path, platform_str, format_str, output_dir)
82
  prev_len = len(history)
83
 
84
  query, image = preprocess_messages(history, img_path, platform_str, format_str)
85
- model_inputs = tokenizer.apply_chat_template(
86
  [{"role": "user", "image": image, "content": query}],
87
  add_generation_prompt=True,
88
  tokenize=True,
@@ -94,12 +94,13 @@ def predict(history, max_length, img_path, platform_str, format_str, output_dir)
94
  tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
95
  )
96
  generate_kwargs = {
97
- "input_ids": model_inputs["input_ids"].to(model.device),
98
- "attention_mask": model_inputs["attention_mask"].to(model.device),
 
 
99
  "streamer": streamer,
100
  "max_length": max_length,
101
- "do_sample": False,
102
- "top_p": 0.0,
103
  "top_k": 1,
104
  }
105
  t = Thread(target=model.generate, kwargs=generate_kwargs)
@@ -187,7 +188,8 @@ def main():
187
  gr.HTML("<h1 align='center'>CogAgent-9B-20241220 Demo</h1>")
188
  gr.HTML(
189
  "<p align='center' style='color:red;'>This demo is for learning and communication purposes only. Users must assume responsibility for the risks associated with AI-generated planning and operations.</p>"
190
- "<p align='left' style='color:black;'>1. Upload an image. 2. Provide your instructions to CogAgent. 3. Wait for CogAgent to return specific operations, and if there are bounding boxes (Bbox), they will be displayed in the image area on the right.</p>"
 
191
  )
192
  with gr.Row():
193
  img_path = gr.Image(label="Upload a Screenshot", type="filepath", height=400)
 
82
  prev_len = len(history)
83
 
84
  query, image = preprocess_messages(history, img_path, platform_str, format_str)
85
+ inputs = tokenizer.apply_chat_template(
86
  [{"role": "user", "image": image, "content": query}],
87
  add_generation_prompt=True,
88
  tokenize=True,
 
94
  tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
95
  )
96
  generate_kwargs = {
97
+ "input_ids": inputs["input_ids"],
98
+ "attention_mask": inputs["attention_mask"],
99
+ "position_ids": inputs["position_ids"],
100
+ "images": inputs["images"],
101
  "streamer": streamer,
102
  "max_length": max_length,
103
+ "do_sample": True,
 
104
  "top_k": 1,
105
  }
106
  t = Thread(target=model.generate, kwargs=generate_kwargs)
 
188
  gr.HTML("<h1 align='center'>CogAgent-9B-20241220 Demo</h1>")
189
  gr.HTML(
190
  "<p align='center' style='color:red;'>This demo is for learning and communication purposes only. Users must assume responsibility for the risks associated with AI-generated planning and operations.</p>"
191
+ "<p align='center' style='color:red;'>In this demo, the model assumes that the user is using a Mac operating system, so it is recommended to upload screenshots from a Mac operating system.</p>"
192
+ "<p align='left' style='color:black;'>1. Upload an image.<br>2. Provide your instructions to CogAgent.<br>3. Wait for CogAgent to return specific operations. If there are bounding boxes (Bbox), they will be displayed in the image area on the right.</p>"
193
  )
194
  with gr.Row():
195
  img_path = gr.Image(label="Upload a Screenshot", type="filepath", height=400)