momergul
commited on
Commit
·
695707a
1
Parent(s):
7ef6c1a
Update
Browse files
app.py
CHANGED
@@ -36,7 +36,6 @@ def initialize_game() -> List[List[str]]:
|
|
36 |
|
37 |
return list(zip(speaker_images, listener_images, targets, roles))
|
38 |
|
39 |
-
@spaces.GPU(duration=10)
|
40 |
def get_model_response(
|
41 |
model, adapter_name, processor, index_to_token, role: str,
|
42 |
image_paths: List[str], user_message: str = "", target_image: str = ""
|
@@ -72,10 +71,10 @@ def get_model_response(
|
|
72 |
|
73 |
@spaces.GPU(duration=15)
|
74 |
def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths, processor, img_dir, index_to_token):
|
|
|
75 |
with torch.no_grad():
|
76 |
-
print(model.model.device, images.device)
|
77 |
captions, _, _, _, _ = model.generate(
|
78 |
-
images, input_tokens, attn_mask, image_attn_mask, label,
|
79 |
image_paths, processor, img_dir, index_to_token,
|
80 |
max_steps=30, sampling_type="nucleus", temperature=0.7,
|
81 |
top_k=50, top_p=1, repetition_penalty=1, num_samples=5
|
@@ -85,11 +84,12 @@ def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask
|
|
85 |
@spaces.GPU(duration=15)
|
86 |
def get_listener_response(model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
|
87 |
s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths):
|
|
|
88 |
with torch.no_grad():
|
89 |
print(model.model.device, images.device)
|
90 |
_, _, joint_log_probs = model.comprehension_side([
|
91 |
-
images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
|
92 |
-
s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label,
|
93 |
])
|
94 |
target_idx = joint_log_probs[0].argmax().item()
|
95 |
response = image_paths[target_idx]
|
|
|
36 |
|
37 |
return list(zip(speaker_images, listener_images, targets, roles))
|
38 |
|
|
|
39 |
def get_model_response(
|
40 |
model, adapter_name, processor, index_to_token, role: str,
|
41 |
image_paths: List[str], user_message: str = "", target_image: str = ""
|
|
|
71 |
|
72 |
@spaces.GPU(duration=15)
|
73 |
def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths, processor, img_dir, index_to_token):
|
74 |
+
model = model.cuda()
|
75 |
with torch.no_grad():
|
|
|
76 |
captions, _, _, _, _ = model.generate(
|
77 |
+
images.cuda(), input_tokens.cuda(), attn_mask.cuda(), image_attn_mask.cuda(), label.cuda(),
|
78 |
image_paths, processor, img_dir, index_to_token,
|
79 |
max_steps=30, sampling_type="nucleus", temperature=0.7,
|
80 |
top_k=50, top_p=1, repetition_penalty=1, num_samples=5
|
|
|
84 |
@spaces.GPU(duration=15)
|
85 |
def get_listener_response(model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
|
86 |
s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths):
|
87 |
+
model = model.cuda()
|
88 |
with torch.no_grad():
|
89 |
print(model.model.device, images.device)
|
90 |
_, _, joint_log_probs = model.comprehension_side([
|
91 |
+
images.cuda(), l_input_tokens.cuda(), l_attn_mask.cuda(), l_image_attn_mask.cuda(), index_to_token,
|
92 |
+
s_input_tokens.cuda(), s_attn_mask.cuda(), s_image_attn_mask.cuda(), s_target_mask.cuda(), s_target_label.cuda(),
|
93 |
])
|
94 |
target_idx = joint_log_probs[0].argmax().item()
|
95 |
response = image_paths[target_idx]
|