HongyuanTao committed
Commit 64283aa · verified
1 Parent(s): f3e2a30

Update modeling_mmMamba_embedding.py

Files changed (1)
  1. modeling_mmMamba_embedding.py +7 -29
modeling_mmMamba_embedding.py CHANGED
@@ -14,52 +14,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-import queue
-import threading
-import warnings
-from typing import List, Optional, Tuple, Union
-from functools import partial
+from typing import Optional, Tuple
 
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from einops import rearrange
 from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-    SequenceClassifierOutputWithPast,
-)
+
 from transformers.modeling_utils import PreTrainedModel
-from transformers.cache_utils import Cache
-from transformers.utils import (
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
-from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
-import copy
+from transformers.utils import logging
+
+from fused_norm_gate import FusedRMSNormSwishGate
+
 from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
 from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-from transformers.cache_utils import Cache
-import time
+
 from timm.models.layers import DropPath
 
 compute_ARank = False # [ARank] Set this to True to compute attention rank
 
-try:
-    from transformers.generation.streamers import BaseStreamer
-except:  # noqa # pylint: disable=bare-except
-    BaseStreamer = None
-
 from .configuration_mmMamba_embedding import mmMambaEmbeddingConfig
 
-import time
-
 from .configuration_mmMamba import mmMambaConfig
 
 try:
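
Aside (not part of the commit): the one functional change in this hunk is that FusedRMSNormSwishGate is now imported from a local fused_norm_gate module rather than from the fla.modules package; the remaining deletions drop imports the file no longer uses. A minimal, hypothetical compatibility sketch, assuming both sources expose the same class, would be:

    # Hypothetical shim, not from the commit: prefer the local fused_norm_gate
    # module introduced here and fall back to the upstream fla package.
    try:
        from fused_norm_gate import FusedRMSNormSwishGate
    except ImportError:
        from fla.modules import FusedRMSNormSwishGate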