Update modeling_mmMamba_embedding.py
modeling_mmMamba_embedding.py (CHANGED)
@@ -410,7 +410,7 @@ class MHA_LM(nn.Module):
         ):
             if self.rotary_emb_dim > 0:
                 q, kv = self.rotary_emb(
-                    q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    q, kv, seqlen_offset=seqlen_offset[:bsz,...], max_seqlen=rotary_max_seqlen
                 )
             if inference_params is None:
                 k, v = kv.unbind(dim=-3)
@@ -538,7 +538,9 @@ class Mamba2_LM(nn.Module):
         conv_state, ssm_state = None, None
         if inference_params is not None:
             conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
-
+            conv_state = conv_state[:batch, ...]
+            ssm_state = ssm_state[:batch, ...]
+
         if use_cache and inference_params.seqlen_offset==0:
             vkq, new_conv_states = causal_conv1d_fn(
                 vkq.transpose(1, 2),
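Both hunks apply the same fix: tensors cached for inference (the rotary seqlen_offset in MHA_LM, and the conv_state/ssm_state recurrent caches in Mamba2_LM) are sliced down to the live batch size before use. This matters when the inference cache was pre-allocated for a larger maximum batch than the batch actually being decoded; without the slice, rows belonging to absent or finished sequences would be mixed into the computation, or cause a shape mismatch. Below is a minimal sketch of the pattern, with hypothetical names and shapes (max_batch, d_inner, d_conv, d_state, and decode_step are illustrative assumptions, not taken from the model code):

import torch

# Hypothetical cache layout; the real mmMamba cache shapes may differ.
max_batch, d_inner, d_conv, d_state = 8, 64, 4, 16

# Inference caches are typically allocated once, for the largest batch seen.
conv_state = torch.zeros(max_batch, d_inner, d_conv)
ssm_state = torch.zeros(max_batch, d_inner, d_state)

def decode_step(x, conv_state, ssm_state):
    """One decode step for a live batch x of shape (batch, d_inner)."""
    batch = x.shape[0]
    # Mirror the commit: slice the pre-allocated caches down to the live
    # batch so rows belonging to other / finished sequences are excluded.
    conv_state = conv_state[:batch, ...]
    ssm_state = ssm_state[:batch, ...]
    # ... the conv / SSM state update would run on the sliced views here ...
    return conv_state, ssm_state

x = torch.randn(3, d_inner)          # live batch (3) smaller than max_batch (8)
cs, ss = decode_step(x, conv_state, ssm_state)
print(cs.shape, ss.shape)            # torch.Size([3, 64, 4]) torch.Size([3, 64, 16])

The MHA_LM hunk follows the same reasoning: seqlen_offset is tracked per cached sequence, so it is truncated to its first bsz entries before being passed to the rotary embedding.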