praeclarumjj3 commited on
Commit
82a8364
·
verified ·
1 Parent(s): ce60be5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -33
app.py CHANGED
@@ -251,44 +251,41 @@ def regenerate(state, image_process_mode):
251
 
252
  @spaces.GPU
253
  def get_interm_outs(state):
254
- print("HERERERE")
255
- print(state)
256
  prompt = state.get_prompt()
257
- print(prompt)
258
  images = state.get_images(return_pil=True)
259
  #prompt, image_args = process_image(prompt, images)
260
 
261
- # if images is not None and len(images) > 0:
262
- # if len(images) > 0:
263
- # if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
264
- # raise ValueError("Number of images does not match number of <image> tokens in prompt")
265
 
266
- # #images = [load_image_from_base64(image) for image in images]
267
- # image_sizes = [image.size for image in images]
268
- # inp_images = process_images(images, image_processor, model.config)
269
-
270
- # if type(inp_images) is list:
271
- # inp_images = [image.to(model.device, dtype=torch.float16) for image in images]
272
- # else:
273
- # inp_images = inp_images.to(model.device, dtype=torch.float16)
274
- # else:
275
- # inp_images = None
276
- # image_sizes = None
277
- # image_args = {"images": inp_images, "image_sizes": image_sizes}
278
- # else:
279
- # inp_images = None
280
- # image_args = {}
281
-
282
- # input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
283
-
284
- # interm_outs = model.get_visual_interpretations(
285
- # input_ids,
286
- # **image_args
287
- # )
288
 
289
- # depth_outs = get_depth_images(interm_outs, image_sizes[0])
290
- # seg_outs = get_seg_images(interm_outs, images[0])
291
- # gen_outs = get_gen_images(interm_outs)
292
 
293
  return images[0], images[0], images[0]
294
 
@@ -450,7 +447,7 @@ with gr.Blocks(title="OLA-VLM", theme=gr.themes.Default(), css=block_css) as dem
450
  btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
451
 
452
  inter_vis_btn.click(
453
- generate,
454
  [state],
455
  [depth_box, seg_box, gen_box],
456
  )
 
251
 
252
  @spaces.GPU
253
  def get_interm_outs(state):
 
 
254
  prompt = state.get_prompt()
 
255
  images = state.get_images(return_pil=True)
256
  #prompt, image_args = process_image(prompt, images)
257
 
258
+ if images is not None and len(images) > 0:
259
+ if len(images) > 0:
260
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
261
+ raise ValueError("Number of images does not match number of <image> tokens in prompt")
262
 
263
+ #images = [load_image_from_base64(image) for image in images]
264
+ image_sizes = [image.size for image in images]
265
+ inp_images = process_images(images, image_processor, model.config)
266
+
267
+ if type(inp_images) is list:
268
+ inp_images = [image.to(model.device, dtype=torch.float16) for image in images]
269
+ else:
270
+ inp_images = inp_images.to(model.device, dtype=torch.float16)
271
+ else:
272
+ inp_images = None
273
+ image_sizes = None
274
+ image_args = {"images": inp_images, "image_sizes": image_sizes}
275
+ else:
276
+ inp_images = None
277
+ image_args = {}
278
+
279
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
280
+
281
+ interm_outs = model.get_visual_interpretations(
282
+ input_ids,
283
+ **image_args
284
+ )
285
 
286
+ depth_outs = get_depth_images(interm_outs, image_sizes[0])
287
+ seg_outs = get_seg_images(interm_outs, images[0])
288
+ gen_outs = get_gen_images(interm_outs)
289
 
290
  return images[0], images[0], images[0]
291
 
 
447
  btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
448
 
449
  inter_vis_btn.click(
450
+ get_interm_outs,
451
  [state],
452
  [depth_box, seg_box, gen_box],
453
  )