momergul committed on
Commit
695707a
·
1 Parent(s): 7ef6c1a
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -36,7 +36,6 @@ def initialize_game() -> List[List[str]]:
36
 
37
  return list(zip(speaker_images, listener_images, targets, roles))
38
 
39
- @spaces.GPU(duration=10)
40
  def get_model_response(
41
  model, adapter_name, processor, index_to_token, role: str,
42
  image_paths: List[str], user_message: str = "", target_image: str = ""
@@ -72,10 +71,10 @@ def get_model_response(
72
 
73
  @spaces.GPU(duration=15)
74
  def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths, processor, img_dir, index_to_token):
 
75
  with torch.no_grad():
76
- print(model.model.device, images.device)
77
  captions, _, _, _, _ = model.generate(
78
- images, input_tokens, attn_mask, image_attn_mask, label,
79
  image_paths, processor, img_dir, index_to_token,
80
  max_steps=30, sampling_type="nucleus", temperature=0.7,
81
  top_k=50, top_p=1, repetition_penalty=1, num_samples=5
@@ -85,11 +84,12 @@ def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask
85
  @spaces.GPU(duration=15)
86
  def get_listener_response(model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
87
  s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths):
 
88
  with torch.no_grad():
89
  print(model.model.device, images.device)
90
  _, _, joint_log_probs = model.comprehension_side([
91
- images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
92
- s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label,
93
  ])
94
  target_idx = joint_log_probs[0].argmax().item()
95
  response = image_paths[target_idx]
 
36
 
37
  return list(zip(speaker_images, listener_images, targets, roles))
38
 
 
39
  def get_model_response(
40
  model, adapter_name, processor, index_to_token, role: str,
41
  image_paths: List[str], user_message: str = "", target_image: str = ""
 
71
 
72
  @spaces.GPU(duration=15)
73
  def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths, processor, img_dir, index_to_token):
74
+ model = model.cuda()
75
  with torch.no_grad():
 
76
  captions, _, _, _, _ = model.generate(
77
+ images.cuda(), input_tokens.cuda(), attn_mask.cuda(), image_attn_mask.cuda(), label.cuda(),
78
  image_paths, processor, img_dir, index_to_token,
79
  max_steps=30, sampling_type="nucleus", temperature=0.7,
80
  top_k=50, top_p=1, repetition_penalty=1, num_samples=5
 
84
  @spaces.GPU(duration=15)
85
  def get_listener_response(model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
86
  s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths):
87
+ model = model.cuda()
88
  with torch.no_grad():
89
  print(model.model.device, images.device)
90
  _, _, joint_log_probs = model.comprehension_side([
91
+ images.cuda(), l_input_tokens.cuda(), l_attn_mask.cuda(), l_image_attn_mask.cuda(), index_to_token,
92
+ s_input_tokens.cuda(), s_attn_mask.cuda(), s_image_attn_mask.cuda(), s_target_mask.cuda(), s_target_label.cuda(),
93
  ])
94
  target_idx = joint_log_probs[0].argmax().item()
95
  response = image_paths[target_idx]