Update modeling_mmMamba.py
modeling_mmMamba.py    CHANGED    +5 -7
@@ -24,22 +24,20 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from einops import rearrange
 from torch import nn
-from torch.nn import
+from torch.nn import CrossEntropyLoss
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (BaseModelOutputWithPast,
-                                           CausalLMOutputWithPast
-                                           SequenceClassifierOutputWithPast)
+                                           CausalLMOutputWithPast)
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (add_start_docstrings,
                                 add_start_docstrings_to_model_forward, logging,
                                 replace_return_docstrings)
-from
-
+from fused_norm_gate import FusedRMSNormSwishGate
+
 from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
 from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-
-import time
+
 
 try:
     from transformers.generation.streamers import BaseStreamer