import torch
from torch import nn


class Transformer(nn.Module):
    """
    Transformer model for sequence-to-sequence tasks: a thin wrapper around
    nn.Transformer with learned token and positional embeddings and a linear
    projection to target-vocabulary logits.
    """

def __init__(
self,
embedding_size,
src_vocab_size,
trg_vocab_size,
src_pad_idx,
num_heads,
num_encoder_layers,
num_decoder_layers,
dropout,
max_len,
device,
):
"""
Initializes the Transformer model.
Args:
embedding_size: Size of the embeddings.
src_vocab_size: Size of the source vocabulary.
trg_vocab_size: Size of the target vocabulary.
src_pad_idx: Index of the padding token in the source vocabulary.
num_heads: Number of attention heads.
num_encoder_layers: Number of encoder layers.
num_decoder_layers: Number of decoder layers.
dropout: Dropout probability.
max_len: Maximum sequence length.
device: Device to place tensors on.
"""
        super().__init__()
# Embeddings for source and target sequences
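        # (positions are embedded with learned lookup tables here, rather than
        # the fixed sinusoidal encodings of the original Transformer paper)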
self.src_embeddings = nn.Embedding(src_vocab_size, embedding_size)
self.src_positional_embeddings = nn.Embedding(max_len, embedding_size)
self.trg_embeddings = nn.Embedding(trg_vocab_size, embedding_size)
self.trg_positional_embeddings = nn.Embedding(max_len, embedding_size)
self.device = device
# Transformer layer
self.transformer = nn.Transformer(
embedding_size,
num_heads,
num_encoder_layers,
num_decoder_layers,
)
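        # dim_feedforward and the stack's internal dropout are left at
        # nn.Transformer's defaults (2048 and 0.1), since they are not passed here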
# Final fully connected layer
self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
self.dropout = nn.Dropout(dropout)
        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
"""
Creates a mask to ignore padding tokens in the source sequence.
Args:
            src: Source sequence tensor of shape (src_len, batch).
        Returns:
            src_mask: Boolean tensor of shape (batch, src_len), True at padding positions.
"""
src_mask = src.transpose(0, 1) == self.src_pad_idx
        return src_mask.to(self.device)

    def forward(self, src, trg):
"""
Forward pass of the Transformer model.
Args:
            src: Source sequence tensor of shape (src_len, batch).
            trg: Target sequence tensor of shape (trg_len, batch).
        Returns:
            out: Logits over the target vocabulary, shape (trg_len, batch, trg_vocab_size).
"""
        src_seq_length, N = src.shape
        trg_seq_length, N = trg.shape
        # Generate position indices for source and target sequences:
        # one column per batch element, shape (seq_length, N)
        src_positions = (
            torch.arange(0, src_seq_length)
            .unsqueeze(1)
            .expand(src_seq_length, N)
            .to(self.device)
        )
        trg_positions = (
            torch.arange(0, trg_seq_length)
            .unsqueeze(1)
            .expand(trg_seq_length, N)
            .to(self.device)
        )
# Apply embeddings and dropout for source and target sequences
embed_src = self.dropout(
(self.src_embeddings(src) + self.src_positional_embeddings(src_positions))
)
embed_trg = self.dropout(
(self.trg_embeddings(trg) + self.trg_positional_embeddings(trg_positions))
)
# Generate masks for source padding and target sequences
src_padding_mask = self.make_src_mask(src)
trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(
self.device
)
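        # trg_mask is an additive (trg_len, trg_len) causal mask that blocks
        # attention to future target positions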
# Forward pass through Transformer
out = self.transformer(
embed_src,
embed_trg,
src_key_padding_mask=src_padding_mask,
tgt_mask=trg_mask,
)
        # Project (trg_len, N, embedding_size) activations to vocabulary logits
        out = self.fc_out(out)
return out
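

# Example usage: a minimal smoke test of the model above. The hyperparameter
# values and random integer batches are illustrative assumptions, not taken
# from any particular training setup. Tokens are laid out sequence-first,
# i.e. (seq_len, batch), as nn.Transformer expects by default.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Transformer(
        embedding_size=512,
        src_vocab_size=1000,
        trg_vocab_size=1000,
        src_pad_idx=0,
        num_heads=8,
        num_encoder_layers=3,
        num_decoder_layers=3,
        dropout=0.1,
        max_len=100,
        device=device,
    ).to(device)
    # Random token ids in [1, 999], avoiding the pad index 0
    src = torch.randint(1, 1000, (10, 32)).to(device)  # 10 tokens, batch of 32
    trg = torch.randint(1, 1000, (7, 32)).to(device)   # 7 tokens, batch of 32
    out = model(src, trg)
    print(out.shape)  # expected: torch.Size([7, 32, 1000])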