Update modeling_mmMamba_embedding.py
modeling_mmMamba_embedding.py
CHANGED
@@ -14,52 +14,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-import
-import threading
-import warnings
-from typing import List, Optional, Tuple, Union
-from functools import partial
+from typing import Optional, Tuple
 
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from einops import rearrange
 from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-    SequenceClassifierOutputWithPast,
-)
+
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import (
-
-
-
-    logging,
-    replace_return_docstrings,
-)
-from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
-import copy
+from transformers.utils import logging
+
+from fused_norm_gate import FusedRMSNormSwishGate
+
 from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
 from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-
-import time
+
 from timm.models.layers import DropPath
 
 compute_ARank = False # [ARank] Set this to True to compute attention rank
 
-try:
-    from transformers.generation.streamers import BaseStreamer
-except: # noqa # pylint: disable=bare-except
-    BaseStreamer = None
-
 from .configuration_mmMamba_embedding import mmMambaEmbeddingConfig
 
-import time
-
 from .configuration_mmMamba import mmMambaConfig
 
 try:
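Note: the functional change in this hunk is that `FusedRMSNormSwishGate` is now imported from the local `fused_norm_gate` module instead of `fla.modules`, while unused imports (`threading`, `warnings`, `functools.partial`, the `modeling_outputs` classes, `BaseStreamer`, `copy`, `time`, and the extra `typing` names) are dropped. For orientation, below is a minimal, unfused reference sketch of what a gated RMSNorm of this kind conventionally computes; the class name `RMSNormSwishGateRef` and the exact convention `y = RMSNorm(x) * weight * silu(gate)` are illustrative assumptions, not the repository's fused Triton kernel.

# Unfused reference sketch (assumption): the common "RMSNorm + Swish gate"
# convention. The actual fused_norm_gate.FusedRMSNormSwishGate kernel may
# differ in details (eps handling, residual support, fused Triton backward).
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNormSwishGateRef(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        # RMS-normalize x over the last dim, apply the learned scale,
        # then gate with SiLU (swish) of the gate branch.
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight * F.silu(gate)

if __name__ == "__main__":
    norm = RMSNormSwishGateRef(hidden_size=64)
    x, gate = torch.randn(2, 8, 64), torch.randn(2, 8, 64)
    print(norm(x, gate).shape)  # torch.Size([2, 8, 64])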