HongyuanTao committed (verified)
Commit 975bc02 · Parent: 300e9bb

Update modeling_mmMamba_embedding.py

Files changed (1)
  1. modeling_mmMamba_embedding.py +4 -2
modeling_mmMamba_embedding.py CHANGED
@@ -410,7 +410,7 @@ class MHA_LM(nn.Module):
         ):
             if self.rotary_emb_dim > 0:
                 q, kv = self.rotary_emb(
-                    q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    q, kv, seqlen_offset=seqlen_offset[:bsz,...], max_seqlen=rotary_max_seqlen
                 )
             if inference_params is None:
                 k, v = kv.unbind(dim=-3)
@@ -538,7 +538,9 @@ class Mamba2_LM(nn.Module):
         conv_state, ssm_state = None, None
         if inference_params is not None:
             conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
-
+            conv_state = conv_state[:batch, ...]
+            ssm_state = ssm_state[:batch, ...]
+
         if use_cache and inference_params.seqlen_offset==0:
             vkq, new_conv_states = causal_conv1d_fn(
                 vkq.transpose(1, 2),
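Both hunks narrow per-request inference state to the live batch: the rotary seqlen_offset is sliced to bsz, and the cached conv_state/ssm_state are sliced to batch. A minimal sketch of the idea, assuming a cache pre-allocated for a maximum batch size (all sizes and names below are illustrative, not the repository's API):

import torch

# Illustrative sizes only; the real layer derives these from its config.
max_batch, d_inner, d_conv, d_state = 8, 64, 4, 16

# Inference caches are typically allocated once for the largest expected batch.
conv_state = torch.zeros(max_batch, d_inner, d_conv)
ssm_state = torch.zeros(max_batch, d_inner, d_state)

def step(x, conv_state, ssm_state):
    # One decoding step for a possibly smaller batch `x` of shape (batch, d_inner).
    batch = x.shape[0]
    # Narrow the cached states to the rows belonging to the current batch,
    # mirroring the added conv_state[:batch, ...] / ssm_state[:batch, ...] lines.
    conv_state = conv_state[:batch, ...]
    ssm_state = ssm_state[:batch, ...]
    return conv_state, ssm_state

conv_active, ssm_active = step(torch.randn(3, d_inner), conv_state, ssm_state)
print(conv_active.shape, ssm_active.shape)  # torch.Size([3, 64, 4]) torch.Size([3, 64, 16])

Without this narrowing, a request whose batch is smaller than the pre-allocated cache would feed mismatched leading dimensions into the rotary embedding and the convolution/SSM update.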