SmerkyG committed
Commit 07f9d2e
1 Parent(s): 058a551

Update modeling_rwkv6qwen2.py


Added a check for the fla import requirement.

Files changed (1):
  1. modeling_rwkv6qwen2.py +8 -2
modeling_rwkv6qwen2.py CHANGED
@@ -204,8 +204,14 @@ class RWKV6State(Cache):
         # self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
         # self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
 
-from fla.ops.gla.chunk import chunk_gla
-from fla.ops.gla.fused_recurrent import fused_recurrent_gla
+try:
+    #from fla.ops.gla.chunk import chunk_gla
+    from fla.ops.gla.fused_recurrent import fused_recurrent_gla
+except ImportError:
+    print("Required module is not installed. Please install it using the following commands:")
+    print("pip install -U git+https://github.com/sustcsonglin/flash-linear-attention")
+    print("Additionally, ensure you have at least version 2.2.0 of Triton installed:")
+    print("pip install triton>=2.2.0")
 
 class RWKV6Attention(nn.Module):
     def __init__(self, config, layer_idx: Optional[int] = None):
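
Note that the guard added here only prints the install instructions: if fla is missing, a later use of fused_recurrent_gla would still fail with a NameError. A minimal sketch of a stricter variant that re-raises the ImportError with the same instructions (the import path and install commands come from the diff above; the raise-instead-of-print design is an assumed alternative, not the repository's code):

# A sketch, not the repository's actual code: raising instead of printing
# is an assumed alternative design for the same dependency check.
try:
    # from fla.ops.gla.chunk import chunk_gla  # kept disabled, as in the commit
    from fla.ops.gla.fused_recurrent import fused_recurrent_gla
except ImportError as exc:
    raise ImportError(
        "flash-linear-attention is required but not installed. Install it with:\n"
        "  pip install -U git+https://github.com/sustcsonglin/flash-linear-attention\n"
        "and make sure Triton >= 2.2.0 is available:\n"
        "  pip install 'triton>=2.2.0'"
    ) from exc

When running the printed command by hand, quoting 'triton>=2.2.0' keeps the shell from treating >= as an output redirection; the unquoted form shown in the commit message would install bare triton and create a file named =2.2.0 in most shells.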