Update modeling_rwkv6qwen2.py
added check for fla import requirement
modeling_rwkv6qwen2.py  CHANGED  (+8 -2)
@@ -204,8 +204,14 @@ class RWKV6State(Cache):
         # self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
         # self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]

-#from fla.ops.gla.chunk import chunk_gla
-from fla.ops.gla.fused_recurrent import fused_recurrent_gla
+try:
+    #from fla.ops.gla.chunk import chunk_gla
+    from fla.ops.gla.fused_recurrent import fused_recurrent_gla
+except ImportError:
+    print("Required module is not installed. Please install it using the following commands:")
+    print("pip install -U git+https://github.com/sustcsonglin/flash-linear-attention")
+    print("Additionally, ensure you have at least version 2.2.0 of Triton installed:")
+    print("pip install triton>=2.2.0")

 class RWKV6Attention(nn.Module):
     def __init__(self, config, layer_idx: Optional[int] = None):