Commit
·
8936db8
1
Parent(s):
4248343
layer renames
Browse files
__pycache__/modeling_minimamba.cpython-312.pyc
CHANGED
Binary files a/__pycache__/modeling_minimamba.cpython-312.pyc and b/__pycache__/modeling_minimamba.cpython-312.pyc differ
|
|
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:395592f9d4d414560ff89dfc7ce95dd4bf258d66daac079b1672923e27d76270
|
3 |
+
size 3065241488
|
modeling_minimamba.py
CHANGED
@@ -52,7 +52,7 @@ class MiniMamba(PreTrainedModel):
|
|
52 |
# But Mamba2 does that internally if config.weight_tying == True.
|
53 |
|
54 |
# This is optional: store any device or dtype you might want
|
55 |
-
self.device_ = torch.
|
56 |
if isinstance(config.torch_dtype, str):
|
57 |
self.dtype_ = getattr(torch, config.torch_dtype)
|
58 |
else:
|
|
|
52 |
# But Mamba2 does that internally if config.weight_tying == True.
|
53 |
|
54 |
# This is optional: store any device or dtype you might want
|
55 |
+
self.device_ = 'cuda' if torch.cuda.is_available() else 'cpu'
|
56 |
if isinstance(config.torch_dtype, str):
|
57 |
self.dtype_ = getattr(torch, config.torch_dtype)
|
58 |
else:
|