HongyuanTao commited on
Commit
f3e2a30
·
verified ·
1 Parent(s): 68654a2

Update modeling_mmMamba.py

Browse files
Files changed (1) hide show
  1. modeling_mmMamba.py +5 -7
modeling_mmMamba.py CHANGED
@@ -24,22 +24,20 @@ import torch.nn.functional as F
24
  import torch.utils.checkpoint
25
  from einops import rearrange
26
  from torch import nn
27
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
  from transformers.activations import ACT2FN
29
  from transformers.modeling_outputs import (BaseModelOutputWithPast,
30
- CausalLMOutputWithPast,
31
- SequenceClassifierOutputWithPast)
32
  from transformers.modeling_utils import PreTrainedModel
33
  from transformers.utils import (add_start_docstrings,
34
  add_start_docstrings_to_model_forward, logging,
35
  replace_return_docstrings)
36
- from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
37
- import copy
38
  from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
39
  from mamba_ssm.ops.triton.selective_state_update import selective_state_update
40
  from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
41
- from transformers.cache_utils import Cache
42
- import time
43
 
44
  try:
45
  from transformers.generation.streamers import BaseStreamer
 
24
  import torch.utils.checkpoint
25
  from einops import rearrange
26
  from torch import nn
27
+ from torch.nn import CrossEntropyLoss
28
  from transformers.activations import ACT2FN
29
  from transformers.modeling_outputs import (BaseModelOutputWithPast,
30
+ CausalLMOutputWithPast)
 
31
  from transformers.modeling_utils import PreTrainedModel
32
  from transformers.utils import (add_start_docstrings,
33
  add_start_docstrings_to_model_forward, logging,
34
  replace_return_docstrings)
35
+ from fused_norm_gate import FusedRMSNormSwishGate
36
+
37
  from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
38
  from mamba_ssm.ops.triton.selective_state_update import selective_state_update
39
  from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
40
+
 
41
 
42
  try:
43
  from transformers.generation.streamers import BaseStreamer