srimanth-d commited on
Commit
5b65c54
·
verified ·
1 Parent(s): 646d03f

Update modeling_got.py

Browse files
Files changed (1) hide show
  1. modeling_got.py +4 -4
modeling_got.py CHANGED
@@ -583,7 +583,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
583
  elif device == "mps" or device == "cpu":
584
  output_ids = self.generate(
585
  input_ids,
586
- images=[image_tensor_1.unsqueeze(0).half().to(device)],
587
  do_sample=False,
588
  num_beams = 1,
589
  no_repeat_ngram_size = 20,
@@ -609,7 +609,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
609
  elif device == "mps" or device == "cpu":
610
  output_ids = self.generate(
611
  input_ids,
612
- images=[image_tensor_1.unsqueeze(0).half().to(device)],
613
  do_sample=False,
614
  num_beams = 1,
615
  no_repeat_ngram_size = 20,
@@ -865,7 +865,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
865
  elif device == "mps" or device == "cpu":
866
  output_ids = self.generate(
867
  input_ids,
868
- images=[image_list.half().to(device)],
869
  do_sample=False,
870
  num_beams = 1,
871
  # no_repeat_ngram_size = 20,
@@ -891,7 +891,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
891
  elif device == "mps" or device == "cpu":
892
  output_ids = self.generate(
893
  input_ids,
894
- images=[image_list.half().to(device)],
895
  do_sample=False,
896
  num_beams = 1,
897
  # no_repeat_ngram_size = 20,
 
583
  elif device == "mps" or device == "cpu":
584
  output_ids = self.generate(
585
  input_ids,
586
+ images=[image_tensor_1.unsqueeze(0).to(device)],
587
  do_sample=False,
588
  num_beams = 1,
589
  no_repeat_ngram_size = 20,
 
609
  elif device == "mps" or device == "cpu":
610
  output_ids = self.generate(
611
  input_ids,
612
+ images=[image_tensor_1.unsqueeze(0).to(device)],
613
  do_sample=False,
614
  num_beams = 1,
615
  no_repeat_ngram_size = 20,
 
865
  elif device == "mps" or device == "cpu":
866
  output_ids = self.generate(
867
  input_ids,
868
+ images=[image_list.to(device)],
869
  do_sample=False,
870
  num_beams = 1,
871
  # no_repeat_ngram_size = 20,
 
891
  elif device == "mps" or device == "cpu":
892
  output_ids = self.generate(
893
  input_ids,
894
+ images=[image_list.to(device)],
895
  do_sample=False,
896
  num_beams = 1,
897
  # no_repeat_ngram_size = 20,