# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch.nn.functional as F

from fairseq.data import Dictionary
from fairseq.models import (
    FairseqDecoder,
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)


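# DummyModel/DummyEncoder form a lightweight stand-in for a Transformer
# language model (e.g. for benchmarking or smoke-testing a training loop):
# each block keeps Transformer-sized linear projections and a feed-forward
# sub-layer but, as the inline comments note, skips the actual self-attention
# computation.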
@register_model("dummy_model")
class DummyModel(FairseqLanguageModel):
    def __init__(self, args, encoder):
        super().__init__(encoder)
        self.args = args

    @staticmethod
    def add_args(parser):
        parser.add_argument("--num-layers", type=int, default=24)
        parser.add_argument("--embed-dim", type=int, default=1024)

    @classmethod
    def build_model(cls, args, task):
        encoder = DummyEncoder(
            num_embed=len(task.target_dictionary),
            embed_dim=args.embed_dim,
            num_layers=args.num_layers,
        )
        return cls(args, encoder)

    def forward(self, src_tokens, masked_tokens=None, **kwargs):
        return self.decoder(src_tokens, masked_tokens=masked_tokens)


class DummyEncoder(FairseqDecoder):
    def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
        super().__init__(Dictionary())
        self.embed = nn.Embedding(
            num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
        )
        self.layers_a = nn.ModuleList(
            [
                nn.Sequential(
                    nn.LayerNorm(embed_dim),
                    nn.Linear(embed_dim, 3 * embed_dim),  # q, k, v input projection
                    nn.Linear(3 * embed_dim, embed_dim),  # skip self-attention
                    nn.Linear(embed_dim, embed_dim),  # output projection
                    nn.Dropout(),
                )
                for i in range(num_layers)
            ]
        )
        self.layers_b = nn.ModuleList(
            [
                nn.Sequential(
                    nn.LayerNorm(embed_dim),
                    nn.Linear(embed_dim, 4 * embed_dim),  # FFN
                    nn.ReLU(),
                    nn.Linear(4 * embed_dim, embed_dim),  # FFN
                    nn.Dropout(0.1),
                )
                for i in range(num_layers)
            ]
        )
        self.out_proj = nn.Linear(embed_dim, num_embed)

    def forward(self, tokens, masked_tokens=None):
        x = self.embed(tokens)
        for layer_a, layer_b in zip(self.layers_a, self.layers_b):
            x = x + layer_a(x)  # residual connection around the attention-sized block
            x = x + layer_b(x)  # residual connection around the FFN block
        x = self.out_proj(x)
        if masked_tokens is not None:
            x = x[masked_tokens]  # keep logits only at the masked positions
        return (x,)

    def max_positions(self):
        return 1024

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        logits = net_output[0].float()
        if log_probs:
            return F.log_softmax(logits, dim=-1)
        else:
            return F.softmax(logits, dim=-1)


@register_model_architecture("dummy_model", "dummy_model")
def base_architecture(args):
    pass
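
# Minimal usage sketch: build a small DummyEncoder directly and run a forward
# pass on random token ids. The sizes below are illustrative choices, not the
# defaults of the classes above; running this requires torch and fairseq.
if __name__ == "__main__":
    import torch

    encoder = DummyEncoder(num_embed=1000, embed_dim=64, num_layers=2)
    tokens = torch.randint(low=1, high=1000, size=(2, 16))  # batch of 2 sequences, length 16
    net_output = encoder(tokens)  # 1-tuple holding logits of shape (2, 16, 1000)
    log_probs = encoder.get_normalized_probs(net_output, log_probs=True)
    print(net_output[0].shape, log_probs.shape)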