make xformers an optional dependency

#6
by NyxKrage - opened
Files changed (1)
  1. model.py +24 -2
model.py CHANGED
@@ -9,7 +9,11 @@ from torch.nn.functional import scaled_dot_product_attention
 from typing import Optional
 import numpy as np
 
-from xformers.ops import SwiGLU
+try:
+    from xformers.ops import SwiGLU
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    XFORMERS_AVAILABLE = False
 
 try:
     from flash_attn.flash_attn_interface import flash_attn_varlen_func
@@ -100,6 +104,21 @@ class NeoBERTConfig(PretrainedConfig):
         self.max_length = max_length
         self.kwargs = kwargs
 
+# Adapted from transformers.models.llama.modeling_llama.LlamaMLP
+class NeobertMLP(nn.Module):
+    def __init__(self, hidden_size, intermediate_size, bias=False):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.w12 = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=bias)
+        self.w3 = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
+        self.act_fn = nn.SiLU()
+
+    def forward(self, x):
+        w1, w2 = self.w12(x).chunk(2, dim=-1)
+        w3 = self.w3(self.act_fn(w1) * w2)
+        return w3
+
 
 class EncoderBlock(nn.Module):
     """Transformer encoder block."""
@@ -117,7 +136,10 @@ class EncoderBlock(nn.Module):
         multiple_of = 8
         intermediate_size = int(2 * config.intermediate_size / 3)
         intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
-        self.ffn = SwiGLU(config.hidden_size, intermediate_size, config.hidden_size, bias=False)
+        if XFORMERS_AVAILABLE:
+            self.ffn = SwiGLU(config.hidden_size, intermediate_size, config.hidden_size, bias=False)
+        else:
+            self.ffn = NeobertMLP(config.hidden_size, intermediate_size, bias=False)
 
         # Layer norms
         self.attention_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
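
For reference, a quick sanity-check sketch (not part of the diff): it copies the NeobertMLP fallback from this PR and confirms that its fused w12 projection matches an unpacked SwiGLU computation, silu(x @ W1^T) * (x @ W2^T) followed by the output projection. The tensor sizes below are illustrative, not NeoBERT's actual config values.

# Sanity check: fused-projection fallback vs. unpacked SwiGLU reference.
import torch
import torch.nn as nn
import torch.nn.functional as F


class NeobertMLP(nn.Module):
    """Copy of the fallback MLP proposed in this PR."""

    def __init__(self, hidden_size, intermediate_size, bias=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.w12 = nn.Linear(hidden_size, 2 * intermediate_size, bias=bias)
        self.w3 = nn.Linear(intermediate_size, hidden_size, bias=bias)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        w1, w2 = self.w12(x).chunk(2, dim=-1)
        return self.w3(self.act_fn(w1) * w2)


torch.manual_seed(0)
hidden, inter = 768, 2048          # illustrative sizes only
mlp = NeobertMLP(hidden, inter)
x = torch.randn(4, 16, hidden)

# Unpacked reference: the fused w12 weight is the gate projection (W1) stacked
# on top of the value projection (W2), so chunking along dim 0 recovers both.
W1, W2 = mlp.w12.weight.chunk(2, dim=0)
ref = (F.silu(x @ W1.T) * (x @ W2.T)) @ mlp.w3.weight.T

assert torch.allclose(mlp(x), ref, atol=1e-5)

This only checks the fallback path against the standard SwiGLU formula; whether it is bit-for-bit identical to xformers.ops.SwiGLU still depends on that kernel's own parametrization and numerics.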