Update modeling_qwen.py
modeling_qwen.py CHANGED (+5 -2)
@@ -565,10 +565,13 @@ class QWenModel(QWenPreTrainedModel):
             images = self.visual.encode(images)
             assert images.shape[0] == len(images)
             fake_images = None
-        else:
+        elif self.training:
             fake_images=torch.zeros(1,3,224,224).to(
                 dtype=self.visual.conv1.weight.dtype, device=self.visual.conv1.weight.device)
             images = self.visual(fake_images)
+        else:
+            fake_images = None
+            images = None
 
         output_attentions = (
             output_attentions
@@ -657,7 +660,7 @@ class QWenModel(QWenPreTrainedModel):
         hidden_states = self.drop(hidden_states).clone()
         if fake_images is not None:
             hidden_states = hidden_states + images.mean()*0
-        else:
+        elif images is not None:
             for idx, (i, a, b) in enumerate(img_pos):
                 hidden_states[i][a + 1 : b] = images[idx]
         output_shape = input_shape + (hidden_states.size(-1),)
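Taken together, the two hunks turn a two-way branch into a three-way dispatch: real images in the batch are encoded as before; a pure-text batch during training still pushes one zero-filled "fake" image through the visual tower (presumably so its parameters stay in the autograd graph and distributed gradient sync does not stall); and, new in this commit, a pure-text batch at inference skips the visual tower entirely. The commit ships with no description, so that reading is an inference. Below is a minimal, self-contained sketch of the resulting control flow; the Visual stub, the encode_images helper, and all shapes are hypothetical stand-ins, not the shipped code.

# Minimal sketch of the patched dispatch (hypothetical stand-ins, not the
# shipped modeling_qwen.py; the real visual tower is Qwen-VL's ViT).
import torch
import torch.nn as nn

class Visual(nn.Module):
    # Stub for self.visual: one conv so it has a conv1.weight like the real tower.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=14, stride=14)

    def forward(self, x):
        # (B, 3, 224, 224) -> (B, 256, 8): a stand-in "image features" sequence.
        return self.conv1(x).flatten(2).transpose(1, 2)

def encode_images(visual, n_images, training):
    if n_images > 0:
        # Branch 1: real images in the batch -- encode them (stubbed here).
        images = visual(torch.randn(n_images, 3, 224, 224))
        fake_images = None
    elif training:
        # Branch 2: pure-text training batch -- run one zero image through the
        # tower so its parameters still take part in backward/gradient sync.
        fake_images = torch.zeros(1, 3, 224, 224).to(
            dtype=visual.conv1.weight.dtype, device=visual.conv1.weight.device)
        images = visual(fake_images)
    else:
        # Branch 3 (added by this commit): pure-text inference -- skip the tower.
        fake_images = None
        images = None
    return images, fake_images

visual = Visual()
images, fake = encode_images(visual, n_images=0, training=False)
assert images is None and fake is None   # inference on pure text: no ViT pass

With that third branch in place, the second hunk's elif images is not None: guard is what keeps the feature-splicing loop from running on a pure-text inference batch, where img_pos was never computed; the images.mean()*0 term in the training path likewise keeps the fake features attached to the graph without altering hidden_states.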