jgpeters commited on
Commit
bb8b1b9
·
verified ·
1 Parent(s): 2928a79

Upload model

Browse files
Files changed (3) hide show
  1. config.json +13 -0
  2. huggingface_wrapper.py +58 -0
  3. model.safetensors +3 -0
config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "METLModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "huggingface_wrapper.METLConfig",
7
+ "AutoModel": "huggingface_wrapper.METLModel"
8
+ },
9
+ "id": null,
10
+ "model_type": "METL",
11
+ "torch_dtype": "float32",
12
+ "transformers_version": "4.42.4"
13
+ }
huggingface_wrapper.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, PreTrainedModel
2
+ import metl.models as models
3
+ from metl.main import *
4
+ from metl.test_cnn import MNIST_CNN
5
+
6
+ class METLConfig(PretrainedConfig):
7
+ IDENT_UUID_MAP = IDENT_UUID_MAP
8
+ UUID_URL_MAP = UUID_URL_MAP
9
+ model_type = "METL"
10
+
11
+ def __init__(
12
+ self,
13
+ id:str = None,
14
+ **kwargs,
15
+ ):
16
+ self.id = id
17
+ super().__init__(**kwargs)
18
+
19
+ # input_shape,
20
+ # num_layers,
21
+ # conv_kernel_sizes,
22
+ # conv_channel_sizes,
23
+ # pool_kernel_sizes
24
+
25
+ class METLModel(PreTrainedModel):
26
+ config_class = METLConfig
27
+ def __init__(self, config:METLConfig):
28
+ super().__init__(config)
29
+ self.model = MNIST_CNN(
30
+ input_shape=(28, 28),
31
+ num_layers = 4,
32
+ conv_kernel_sizes = [4, 3, 3, 3],
33
+ conv_channel_sizes = [(1, 16), (16, 32), (32, 64), (64, 128)],
34
+ pool_kernel_sizes = [2, 2, 2]
35
+ )
36
+ self.encoder = None
37
+ self.config = config
38
+
39
+ def forward(self, X, pdb_fn=None):
40
+ if pdb_fn:
41
+ return self.model(X, pdb_fn=pdb_fn)
42
+ return self.model(X)
43
+
44
+ def load_from_uuid(self, id):
45
+ if id:
46
+ id = id.lower()
47
+ assert id in self.config.UUID_URL_MAP, "ID given does not reference a valid METL model in the IDENT_UUID_MAP"
48
+ self.config.id = id
49
+
50
+ self.model, self.encoder = get_from_uuid(self.config.UUID_URL_MAP[self.config.id])
51
+
52
+ def load_from_indent(self, id):
53
+ if id:
54
+ id = id.lower()
55
+ assert id in self.config.IDENT_UUID_MAP, "ID given does not reference a valid METL model in the IDENT_UUID_MAP"
56
+ self.config.id = id
57
+
58
+ self.model, self.encoder = get_from_ident(self.config.IDENT_UUID_MAP[self.config.id])
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83f15e4d00e58747d8a38065dd176620918cb67ebef549ee8fbb0ef74a9b727f
3
+ size 389768