czczup committed on
Commit
4b3bcd0
1 Parent(s): db62f28

norm function changes dtypes (#13)

Browse files

- convert norm function output to original dtype (771117017599deb36494828a10d72790c69929f6)

Files changed (1) hide show
  1. modeling_intern_vit.py +3 -2
modeling_intern_vit.py CHANGED
@@ -287,9 +287,10 @@ class InternVisionEncoderLayer(nn.Module):
287
  Args:
288
  hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
289
  """
290
- hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
291
 
292
- hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
 
 
293
 
294
  return hidden_states
295
 
 
287
  Args:
288
  hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
289
  """
 
290
 
291
+ hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1)
292
+
293
+ hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2)
294
 
295
  return hidden_states
296