import torch
from torch import nn
import math
from pytorch_transformers.modeling_bert import (
    BertEncoder,
    BertPreTrainedModel,
    BertConfig,
)


class GeLU(nn.Module):
    """Implementation of the GELU activation function.

    For information: OpenAI GPT's GELU is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    Also see https://arxiv.org/abs/1606.08415
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
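

# Illustrative sketch, not part of the original model: compares the exact erf-based
# GELU above with the tanh approximation quoted in its docstring. The two forms are
# expected to agree closely over typical activation ranges.
def _gelu_example():
    x = torch.linspace(-3.0, 3.0, steps=7)
    exact = GeLU()(x)
    approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    # Maximum absolute difference; expected to be small (the tanh form is a close approximation).
    print(torch.max(torch.abs(exact - approx)).item())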


class BertLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias
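

# Illustrative sketch, not part of the original model: with freshly initialised
# parameters (weight = 1, bias = 0), this TF-style layer norm matches
# torch.nn.LayerNorm over the last dimension with the same epsilon.
def _layernorm_example(hidden_size=8):
    x = torch.randn(2, 4, hidden_size)
    ours = BertLayerNorm(hidden_size, eps=1e-12)(x)
    ref = nn.LayerNorm(hidden_size, eps=1e-12)(x)
    print(torch.allclose(ours, ref, atol=1e-6))  # expected: True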


class mlp_meta(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(config.hid_dim, config.hid_dim),
            GeLU(),
            BertLayerNorm(config.hid_dim, eps=1e-12),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        return self.mlp(x)


class Bert_Transformer_Layer(BertPreTrainedModel):
    def __init__(self, fusion_config):
        super().__init__(BertConfig(**fusion_config))
        bertconfig_fusion = BertConfig(**fusion_config)
        self.encoder = BertEncoder(bertconfig_fusion)
        self.init_weights()

    def forward(self, input, mask=None):
        """
        input: (bs, feats, dim), e.g. (bs, 4, dim)
        mask (optional): (bs, feats - 1); the first token is always kept visible.
        """
        batch, feats, dim = input.size()
        if mask is not None:
            # Build a (batch, 1, feats, feats) pairwise mask from the per-token mask.
            mask_ = torch.ones(size=(batch, feats), device=mask.device)
            mask_[:, 1:] = mask
            mask_ = torch.bmm(mask_.view(batch, 1, -1).transpose(1, 2), mask_.view(batch, 1, -1))
            mask_ = mask_.unsqueeze(1)
        else:
            # No mask given: every position attends to every other position.
            mask = torch.Tensor([1.0]).to(input.device)
            mask_ = mask.repeat(batch, 1, feats, feats)

        # Additive attention mask: 0 for visible positions, -10000 for masked ones.
        extend_mask = (1 - mask_) * -10000
        assert not extend_mask.requires_grad
        head_mask = [None] * self.config.num_hidden_layers

        enc_output = self.encoder(input, extend_mask, head_mask=head_mask)
        output = enc_output[0]
        all_attention = enc_output[1]
        return output, all_attention
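

# Illustrative sketch, not part of the original model: fuses a small set of token
# features with one self-attention layer. The hidden size must be divisible by the
# number of attention heads (4 in this configuration); the sizes below are hypothetical.
def _fusion_example():
    cfg = {'hidden_size': 16, 'num_hidden_layers': 1,
           'num_attention_heads': 4, 'output_attentions': True}
    layer = Bert_Transformer_Layer(cfg)
    tokens = torch.randn(2, 4, 16)   # (batch, feats, dim)
    out, attn = layer(tokens)        # no mask: full attention over the 4 feature tokens
    print(out.shape)                 # torch.Size([2, 4, 16]); attn holds per-layer attention maps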


class mmdPreModel(nn.Module):
    def __init__(self, config, num_mlp=0, transformer_flag=False, num_hidden_layers=1, mlp_flag=True):
        super(mmdPreModel, self).__init__()
        self.num_mlp = num_mlp
        self.transformer_flag = transformer_flag
        self.mlp_flag = mlp_flag
        token_num = config.token_num

        self.mlp = nn.Sequential(
            nn.Linear(config.in_dim, config.hid_dim),
            GeLU(),
            BertLayerNorm(config.hid_dim, eps=1e-12),
            nn.Dropout(config.dropout),
            # nn.Linear(config.hid_dim, config.out_dim),
        )
        self.fusion_config = {
            'hidden_size': config.in_dim,
            'num_hidden_layers': num_hidden_layers,
            'num_attention_heads': 4,
            'output_attentions': True,
        }
        if self.num_mlp > 0:
            self.mlp2 = nn.ModuleList([mlp_meta(config) for _ in range(self.num_mlp)])
        if self.transformer_flag:
            self.transformer = Bert_Transformer_Layer(self.fusion_config)
        self.feature = nn.Linear(config.hid_dim * token_num, config.out_dim)

    def forward(self, features):
        """
        input: [batch, token_num, in_dim], output: [batch, out_dim]
        """
        if self.transformer_flag:
            # Fuse the token features with self-attention before the projection MLP.
            features, _ = self.transformer(features)
        if self.mlp_flag:
            features = self.mlp(features)
        if self.num_mlp > 0:
            for mlp in self.mlp2:
                features = mlp(features)
        # Flatten the token dimension and project to the output space.
        features = self.feature(features.view(features.shape[0], -1))
        return features
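

# Illustrative sketch, not part of the original model: the config values below are
# hypothetical; the real in_dim/token_num depend on the upstream feature extractor.
def _mmd_premodel_example():
    from types import SimpleNamespace
    cfg = SimpleNamespace(in_dim=768, hid_dim=512, out_dim=300, token_num=4, dropout=0.2)
    model = mmdPreModel(cfg, num_mlp=0, transformer_flag=True, num_hidden_layers=1)
    feats = torch.randn(8, cfg.token_num, cfg.in_dim)   # (batch, token_num, in_dim)
    out = model(feats)
    print(out.shape)  # torch.Size([8, 300])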