qnguyen3 committed on
Commit
40f0486
·
verified ·
1 Parent(s): 85470c7

Update modeling_llava_qwen2.py

Browse files
Files changed (1) hide show
  1. modeling_llava_qwen2.py +5 -5
modeling_llava_qwen2.py CHANGED
@@ -538,13 +538,13 @@ class SigLipVisionTower(nn.Module):
538
  if type(images) is list:
539
  image_features = []
540
  for image in images:
541
- image_forward_out = self.vision_tower(image.to(device="cuda:0", dtype=self.dtype).unsqueeze(0),
542
  output_hidden_states=True)
543
  image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
544
  assert image_features.shape[-2] == 729
545
  image_features.append(image_feature)
546
  else:
547
- image_forward_outs = self.vision_tower(images.to(device="cuda:0", dtype=self.dtype),
548
  output_hidden_states=True)
549
  image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
550
  assert image_features.shape[-2] == 729
@@ -553,7 +553,7 @@ class SigLipVisionTower(nn.Module):
553
 
554
  @property
555
  def dummy_feature(self):
556
- return torch.zeros(1, self.hidden_size, device="cuda:0", dtype=self.dtype)
557
 
558
  @property
559
  def dtype(self):
@@ -685,9 +685,9 @@ class LlavaMetaForCausalLM(ABC):
685
  image_features = self.encode_images(concat_images)
686
  split_sizes = [image.shape[0] for image in images]
687
  image_features = torch.split(image_features, split_sizes, dim=0)
688
- image_features = [x.flatten(0, 1).to("cuda:0") for x in image_features]
689
  else:
690
- image_features = self.encode_images(images).to("cuda:0")
691
 
692
  # Let's just add dummy tensors if they do not exist,
693
  # it is a headache to deal with None all the time.
 
538
  if type(images) is list:
539
  image_features = []
540
  for image in images:
541
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
542
  output_hidden_states=True)
543
  image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
544
  assert image_features.shape[-2] == 729
545
  image_features.append(image_feature)
546
  else:
547
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
548
  output_hidden_states=True)
549
  image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
550
  assert image_features.shape[-2] == 729
 
553
 
554
  @property
555
  def dummy_feature(self):
556
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
557
 
558
  @property
559
  def dtype(self):
 
685
  image_features = self.encode_images(concat_images)
686
  split_sizes = [image.shape[0] for image in images]
687
  image_features = torch.split(image_features, split_sizes, dim=0)
688
+ image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
689
  else:
690
+ image_features = self.encode_images(images).to(self.device)
691
 
692
  # Let's just add dummy tensors if they do not exist,
693
  # it is a headache to deal with None all the time.