qnguyen3 commited on
Commit
07368a8
·
verified ·
1 Parent(s): 53860b6

Update modeling_llava_qwen2.py

Browse files
Files changed (1) hide show
  1. modeling_llava_qwen2.py +5 -5
modeling_llava_qwen2.py CHANGED
@@ -535,13 +535,13 @@ class SigLipVisionTower(nn.Module):
535
  if type(images) is list:
536
  image_features = []
537
  for image in images:
538
- image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
539
  output_hidden_states=True)
540
  image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
541
  assert image_features.shape[-2] == 729
542
  image_features.append(image_feature)
543
  else:
544
- image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
545
  output_hidden_states=True)
546
  image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
547
  assert image_features.shape[-2] == 729
@@ -550,7 +550,7 @@ class SigLipVisionTower(nn.Module):
550
 
551
  @property
552
  def dummy_feature(self):
553
- return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
554
 
555
  @property
556
  def dtype(self):
@@ -682,9 +682,9 @@ class LlavaMetaForCausalLM(ABC):
682
  image_features = self.encode_images(concat_images)
683
  split_sizes = [image.shape[0] for image in images]
684
  image_features = torch.split(image_features, split_sizes, dim=0)
685
- image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
686
  else:
687
- image_features = self.encode_images(images).to(self.device)
688
 
689
  # Let's just add dummy tensors if they do not exist,
690
  # it is a headache to deal with None all the time.
 
535
  if type(images) is list:
536
  image_features = []
537
  for image in images:
538
+ image_forward_out = self.vision_tower(image.to(device="cuda:0", dtype=self.dtype).unsqueeze(0),
539
  output_hidden_states=True)
540
  image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
541
  assert image_features.shape[-2] == 729
542
  image_features.append(image_feature)
543
  else:
544
+ image_forward_outs = self.vision_tower(images.to(device="cuda:0", dtype=self.dtype),
545
  output_hidden_states=True)
546
  image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
547
  assert image_features.shape[-2] == 729
 
550
 
551
  @property
552
  def dummy_feature(self):
553
+ return torch.zeros(1, self.hidden_size, device="cuda:0", dtype=self.dtype)
554
 
555
  @property
556
  def dtype(self):
 
682
  image_features = self.encode_images(concat_images)
683
  split_sizes = [image.shape[0] for image in images]
684
  image_features = torch.split(image_features, split_sizes, dim=0)
685
+ image_features = [x.flatten(0, 1).to("cuda:0") for x in image_features]
686
  else:
687
+ image_features = self.encode_images(images).to("cuda:0")
688
 
689
  # Let's just add dummy tensors if they do not exist,
690
  # it is a headache to deal with None all the time.