Leyo commited on
Commit
b298ee1
·
1 Parent(s): 8254757

fix issues with erf and xavier init

Browse files
Files changed (1) hide show
  1. modeling_siglip.py +15 -9
modeling_siglip.py CHANGED
@@ -95,7 +95,12 @@ def _trunc_normal_(tensor, mean, std, a, b):
95
 
96
  # Use inverse cdf transform for normal distribution to get truncated
97
  # standard normal
98
- tensor.erfinv_()
 
 
 
 
 
99
 
100
  # Transform to proper mean, std
101
  tensor.mul_(std * math.sqrt(2.0))
@@ -670,6 +675,7 @@ class SiglipPreTrainedModel(PreTrainedModel):
670
 
671
  def _init_weights(self, module):
672
  """Initialize the weights"""
 
673
  if isinstance(module, SiglipVisionEmbeddings):
674
  width = (
675
  self.config.vision_config.hidden_size
@@ -680,22 +686,22 @@ class SiglipPreTrainedModel(PreTrainedModel):
680
  elif isinstance(module, nn.Embedding):
681
  default_flax_embed_init(module.weight)
682
  elif isinstance(module, SiglipAttention):
683
- nn.init.xavier_uniform_(module.q_proj.weight)
684
- nn.init.xavier_uniform_(module.k_proj.weight)
685
- nn.init.xavier_uniform_(module.v_proj.weight)
686
- nn.init.xavier_uniform_(module.out_proj.weight)
687
  nn.init.zeros_(module.q_proj.bias)
688
  nn.init.zeros_(module.k_proj.bias)
689
  nn.init.zeros_(module.v_proj.bias)
690
  nn.init.zeros_(module.out_proj.bias)
691
  elif isinstance(module, SiglipMLP):
692
- nn.init.xavier_uniform_(module.fc1.weight)
693
- nn.init.xavier_uniform_(module.fc2.weight)
694
  nn.init.normal_(module.fc1.bias, std=1e-6)
695
  nn.init.normal_(module.fc2.bias, std=1e-6)
696
  elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
697
- nn.init.xavier_uniform_(module.probe.data)
698
- nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
699
  nn.init.zeros_(module.attention.in_proj_bias.data)
700
  elif isinstance(module, SiglipModel):
701
  logit_scale_init = torch.log(torch.tensor(1.0))
 
95
 
96
  # Use inverse cdf transform for normal distribution to get truncated
97
  # standard normal
98
+ if tensor.dtype == torch.bfloat16:
99
+ tensor = tensor.to(torch.float32)
100
+ tensor.erfinv_()
101
+ tensor = tensor.to(torch.bfloat16)
102
+ else:
103
+ tensor.erfinv_()
104
 
105
  # Transform to proper mean, std
106
  tensor.mul_(std * math.sqrt(2.0))
 
675
 
676
  def _init_weights(self, module):
677
  """Initialize the weights"""
678
+
679
  if isinstance(module, SiglipVisionEmbeddings):
680
  width = (
681
  self.config.vision_config.hidden_size
 
686
  elif isinstance(module, nn.Embedding):
687
  default_flax_embed_init(module.weight)
688
  elif isinstance(module, SiglipAttention):
689
+ nn.init.normal_(module.q_proj.weight)
690
+ nn.init.normal_(module.k_proj.weight)
691
+ nn.init.normal_(module.v_proj.weight)
692
+ nn.init.normal_(module.out_proj.weight)
693
  nn.init.zeros_(module.q_proj.bias)
694
  nn.init.zeros_(module.k_proj.bias)
695
  nn.init.zeros_(module.v_proj.bias)
696
  nn.init.zeros_(module.out_proj.bias)
697
  elif isinstance(module, SiglipMLP):
698
+ nn.init.normal_(module.fc1.weight)
699
+ nn.init.normal_(module.fc2.weight)
700
  nn.init.normal_(module.fc1.bias, std=1e-6)
701
  nn.init.normal_(module.fc2.bias, std=1e-6)
702
  elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
703
+ nn.init.normal_(module.probe.data)
704
+ nn.init.normal_(module.attention.in_proj_weight.data)
705
  nn.init.zeros_(module.attention.in_proj_bias.data)
706
  elif isinstance(module, SiglipModel):
707
  logit_scale_init = torch.log(torch.tensor(1.0))