fracapuano committed
Commit 65f400e
1 Parent(s): 8bf2dea

Add EEGViT model

Files changed (3)
  1. config.json +25 -0
  2. eegvit_model.py +48 -0
  3. model.safetensors +3 -0
config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "architectures": [
+     "EEGViTAutoModel"
+   ],
+   "attention_probs_dropout_prob": 0.0,
+   "auto_map": {
+     "AutoModel": "eegvit_model.EEGViTAutoModel"
+   },
+   "encoder_stride": 16,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.0,
+   "hidden_size": 768,
+   "image_size": 224,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "model_type": "vit",
+   "num_attention_heads": 12,
+   "num_channels": 3,
+   "num_hidden_layers": 12,
+   "patch_size": 16,
+   "qkv_bias": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.46.1"
+ }
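Note: the "auto_map" entry above is what lets Transformers resolve the custom class from eegvit_model.py when remote code is enabled. A minimal loading sketch, assuming a repository id of fracapuano/eegvit (the actual repo id is not shown in this diff):

    from transformers import AutoModel

    # trust_remote_code is required so that eegvit_model.EEGViTAutoModel is imported from the repo
    model = AutoModel.from_pretrained("fracapuano/eegvit", trust_remote_code=True)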
eegvit_model.py ADDED
@@ -0,0 +1,48 @@
+ import torch
+ from torch import nn
+ import transformers
+ from transformers import PreTrainedModel
+
+
+ class EEGViTAutoModel(PreTrainedModel):
+     """Wrapper so the EEGViT backbone can be loaded through AutoModel."""
+
+     config_class = transformers.ViTConfig
+
+     def __init__(self, config=None):
+         if config is None:
+             config = transformers.ViTConfig()
+         super().__init__(config)
+         self.model = EEGViT_pretrained()
+
+     def forward(self, x):
+         # Delegate to the wrapped backbone so AutoModel-loaded instances are callable.
+         return self.model(x)
+
+
+ class EEGViT_pretrained(nn.Module):
+     """EEG front-end (temporal conv + batch norm) feeding an ImageNet-pretrained ViT."""
+
+     def __init__(self):
+         super().__init__()
+         # Temporal convolution: 1 input channel -> 256 feature maps along the time axis.
+         self.conv1 = nn.Conv2d(
+             in_channels=1,
+             out_channels=256,
+             kernel_size=(1, 36),
+             stride=(1, 36),
+             padding=(0, 2),
+             bias=False,
+         )
+         self.batchnorm1 = nn.BatchNorm2d(256, False)
+
+         # Adapt a pretrained ViT to the 256-channel, (129, 14) EEG feature maps.
+         model_name = "google/vit-base-patch16-224"
+         config = transformers.ViTConfig.from_pretrained(model_name)
+         config.update({"num_channels": 256})
+         config.update({"image_size": (129, 14)})
+         config.update({"patch_size": (8, 1)})
+
+         model = transformers.ViTForImageClassification.from_pretrained(
+             model_name, config=config, ignore_mismatched_sizes=True
+         )
+         # Depthwise patch projection matching the new channel count and patch size.
+         model.vit.embeddings.patch_embeddings.projection = torch.nn.Conv2d(
+             256, 768, kernel_size=(8, 1), stride=(8, 1), padding=(0, 0), groups=256
+         )
+         # Two-layer regression head with a 2-dimensional output.
+         model.classifier = torch.nn.Sequential(
+             torch.nn.Linear(768, 1000, bias=True),
+             torch.nn.Dropout(p=0.1),
+             torch.nn.Linear(1000, 2, bias=True),
+         )
+         self.ViT = model
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.batchnorm1(x)
+         x = self.ViT.forward(x).logits
+         return x
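For a quick sanity check of the geometry above (conv1 with kernel and stride (1, 36), and a ViT expecting 256-channel maps of size (129, 14)), here is a minimal smoke-test sketch. The EEGEyeNet-style input of 129 electrodes by 500 time samples is an assumption and is not part of the commit:

    import torch
    from eegvit_model import EEGViT_pretrained

    # Instantiating the model downloads the google/vit-base-patch16-224 weights
    model = EEGViT_pretrained()
    model.eval()

    x = torch.randn(2, 1, 129, 500)  # (batch, 1, electrodes, time samples)
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # expected: torch.Size([2, 2])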
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7147906aa1f15600d12f0799c6cf0117ffc407c017d74817977eaff9c04ad91e
+ size 344096872
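The file above is only a Git LFS pointer; the roughly 344 MB checkpoint itself is stored by LFS. A sketch for inspecting it locally once downloaded (for example via git lfs pull or huggingface_hub), using the safetensors library:

    from safetensors.torch import load_file

    state_dict = load_file("model.safetensors")
    print(len(state_dict), "tensors")
    print(sum(t.numel() for t in state_dict.values()), "parameters")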