from math import sqrt, log
import sys

# sys.path.append("../energy")  # Messy

import torch
import torch.nn as nn
from torch.nn.functional import softmax, relu, linear
from common import PositionalEncoding
from hopfield import HopfieldLayer, HopfieldMHA, HopfieldReLU, HopfieldSoftmax

from torch.cuda.amp import autocast
import yaml
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import MaskedLMOutput, BaseModelOutput

class BertEnergyConfig(PretrainedConfig):
    """Hugging Face configuration for the energy-based BERT model.

    Field values come from three sources, in decreasing precedence: an
    existing config object (`config`), a YAML file (`path`), then the
    keyword defaults below.
    """

    model_type = "bert_energy"

    def __init__(self, config=None, path=None, vocabulary_size=50, num_layers=12,
                 num_heads=12, forward_memories=2048, embedding_dim=768,
                 activation="relu", positional=True, bias=True, tie_weights=True,
                 alpha=1.0, beta=1.0, layer_norm=1e-05, dropout=0.0,
                 block_size=512, share_layers=False, compile=False, pad_idx=None,
                 **kwargs):

        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.activation = activation
        self.positional = positional
        self.tie_weights = tie_weights
        self.bias = bias
        self.forward_memories = forward_memories
        self.embedding_dim = embedding_dim
        self.share_layers = share_layers
        self.alpha = alpha
        self.beta = beta
        self.layer_norm = float(layer_norm)  # LayerNorm epsilon; cast guards against YAML reading "1e-05" as a string
        self.dropout = dropout
        self.block_size = block_size
        self.compile = compile
        self.pad_idx = pad_idx

        # An existing config object takes precedence: copy over any matching fields.
        if config is not None:
            for key, value in config.to_dict().items():
                if key.lower() in self.__dict__:
                    setattr(self, key.lower(), value)

        # Otherwise, optionally load field overrides from a YAML file.
        elif path is not None:
            if path.endswith(".yaml"):
                with open(path) as istream:
                    config = yaml.safe_load(istream)
                for key, value in config.items():
                    if key.lower() in self.__dict__:
                        setattr(self, key.lower(), value)
            else:
                raise NotImplementedError(f"Unsupported config format: {path}")
        # Remaining kwargs (standard Hugging Face fields) go to PretrainedConfig.
        super().__init__(**kwargs)
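

# Minimal usage sketch; the YAML path below is hypothetical. Any YAML file
# whose keys match the constructor arguments above (num_layers, embedding_dim,
# ...) works, and keys are matched case-insensitively.
if __name__ == "__main__":
    cfg = BertEnergyConfig(vocabulary_size=32000, num_layers=6)
    print(cfg.model_type, cfg.num_layers)  # -> bert_energy 6

    # cfg = BertEnergyConfig(path="experiment.yaml")  # hypothetical file

    # Optional: let AutoConfig.from_pretrained resolve "bert_energy" checkpoints.
    # from transformers import AutoConfig
    # AutoConfig.register("bert_energy", BertEnergyConfig)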