Upload model
Browse files- config.json +13 -0
- huggingface_wrapper.py +58 -0
- model.safetensors +3 -0
config.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"METLModel"
|
4 |
+
],
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "huggingface_wrapper.METLConfig",
|
7 |
+
"AutoModel": "huggingface_wrapper.METLModel"
|
8 |
+
},
|
9 |
+
"id": null,
|
10 |
+
"model_type": "METL",
|
11 |
+
"torch_dtype": "float32",
|
12 |
+
"transformers_version": "4.42.4"
|
13 |
+
}
|
huggingface_wrapper.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig, PreTrainedModel
|
2 |
+
import metl.models as models
|
3 |
+
from metl.main import *
|
4 |
+
from metl.test_cnn import MNIST_CNN
|
5 |
+
|
6 |
+
class METLConfig(PretrainedConfig):
|
7 |
+
IDENT_UUID_MAP = IDENT_UUID_MAP
|
8 |
+
UUID_URL_MAP = UUID_URL_MAP
|
9 |
+
model_type = "METL"
|
10 |
+
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
id:str = None,
|
14 |
+
**kwargs,
|
15 |
+
):
|
16 |
+
self.id = id
|
17 |
+
super().__init__(**kwargs)
|
18 |
+
|
19 |
+
# input_shape,
|
20 |
+
# num_layers,
|
21 |
+
# conv_kernel_sizes,
|
22 |
+
# conv_channel_sizes,
|
23 |
+
# pool_kernel_sizes
|
24 |
+
|
25 |
+
class METLModel(PreTrainedModel):
|
26 |
+
config_class = METLConfig
|
27 |
+
def __init__(self, config:METLConfig):
|
28 |
+
super().__init__(config)
|
29 |
+
self.model = MNIST_CNN(
|
30 |
+
input_shape=(28, 28),
|
31 |
+
num_layers = 4,
|
32 |
+
conv_kernel_sizes = [4, 3, 3, 3],
|
33 |
+
conv_channel_sizes = [(1, 16), (16, 32), (32, 64), (64, 128)],
|
34 |
+
pool_kernel_sizes = [2, 2, 2]
|
35 |
+
)
|
36 |
+
self.encoder = None
|
37 |
+
self.config = config
|
38 |
+
|
39 |
+
def forward(self, X, pdb_fn=None):
|
40 |
+
if pdb_fn:
|
41 |
+
return self.model(X, pdb_fn=pdb_fn)
|
42 |
+
return self.model(X)
|
43 |
+
|
44 |
+
def load_from_uuid(self, id):
|
45 |
+
if id:
|
46 |
+
id = id.lower()
|
47 |
+
assert id in self.config.UUID_URL_MAP, "ID given does not reference a valid METL model in the IDENT_UUID_MAP"
|
48 |
+
self.config.id = id
|
49 |
+
|
50 |
+
self.model, self.encoder = get_from_uuid(self.config.UUID_URL_MAP[self.config.id])
|
51 |
+
|
52 |
+
def load_from_indent(self, id):
|
53 |
+
if id:
|
54 |
+
id = id.lower()
|
55 |
+
assert id in self.config.IDENT_UUID_MAP, "ID given does not reference a valid METL model in the IDENT_UUID_MAP"
|
56 |
+
self.config.id = id
|
57 |
+
|
58 |
+
self.model, self.encoder = get_from_ident(self.config.IDENT_UUID_MAP[self.config.id])
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:83f15e4d00e58747d8a38065dd176620918cb67ebef549ee8fbb0ef74a9b727f
|
3 |
+
size 389768
|