import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Configuration
config = {
    "learning_rate": 1e-4,
    "batch_size": 32,
    "vocab_size": 30522,
    "max_len": 256,
    "hidden_size": 768,
    "dropout": 0.1,
    "n_layer": 12,
    "n_head": 12,
    "ff_expansion_factor": 4,
    "rnn_units": 768,
    "num_labels": 5
}
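# Illustrative sanity check (not required by the model code itself): nn.MultiheadAttention
# needs hidden_size to split evenly across attention heads -- here 768 / 12 = 64 dims per head.
assert config["hidden_size"] % config["n_head"] == 0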
class MyClass:
    def __init__(self, value):
        self.value = value
# Custom Initializer
def custom_initializer(shape):
    # Draw initial weights from N(0, 0.02), the scale commonly used for transformer embeddings.
    return torch.normal(mean=0.0, std=0.02, size=shape)
class CustomEmbedding(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(CustomEmbedding, self).__init__()
        # _weight lets nn.Embedding start from the custom normal initialization above.
        self.embedding = nn.Embedding(vocab_size, hidden_size,
                                      _weight=custom_initializer((vocab_size, hidden_size)))

    def forward(self, inputs):
        return self.embedding(inputs)
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding stored batch-first: pe has shape (1, max_len, n_embd)
    so it broadcasts over the batch dimension, matching the batch-first tensors used below."""
    def __init__(self, n_embd, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.n_embd = n_embd
        self.max_len = max_len
        pe = torch.zeros(max_len, n_embd)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, n_embd, 2).float() * -(np.log(10000.0) / n_embd))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, n_embd)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, n_embd); slice the table to the current sequence length.
        return x + self.pe[:, :x.size(1), :]
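# Illustrative check (assumes the batch-first convention used throughout this file):
# adding the encoding should leave the (batch, seq, hidden) shape unchanged.
_pe_demo = PositionalEncoding(config["hidden_size"], config["max_len"])
assert _pe_demo(torch.zeros(2, 16, config["hidden_size"])).shape == (2, 16, config["hidden_size"])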
class MultiheadAttention(nn.Module):
    def __init__(self, config):
        super(MultiheadAttention, self).__init__()
        # batch_first=True so inputs are (batch, seq, hidden), consistent with the LSTM below.
        self.attention = nn.MultiheadAttention(config['hidden_size'], config['n_head'],
                                               dropout=config['dropout'], batch_first=True)

    def forward(self, v, k, q, mask=None):
        # nn.MultiheadAttention takes (query, key, value).
        attn_output, attn_output_weights = self.attention(q, k, v, attn_mask=mask)
        return attn_output
class FeedForward(nn.Module):
    def __init__(self, config):
        super(FeedForward, self).__init__()
        # Position-wise feed-forward: expand by ff_expansion_factor, then project back.
        self.dense1 = nn.Linear(config['hidden_size'], config['hidden_size'] * config['ff_expansion_factor'])
        self.dense2 = nn.Linear(config['hidden_size'] * config['ff_expansion_factor'], config['hidden_size'])
        self.dropout = nn.Dropout(config['dropout'])

    def forward(self, x):
        x = F.gelu(self.dense1(x))
        x = self.dropout(x)
        return self.dense2(x)
class TransformerXLBlock(nn.Module):
    def __init__(self, config):
        super(TransformerXLBlock, self).__init__()
        self.attn = MultiheadAttention(config)
        self.ff = FeedForward(config)
        self.ln1 = nn.LayerNorm(config['hidden_size'])
        self.ln2 = nn.LayerNorm(config['hidden_size'])

    def forward(self, x, mask=None):
        # Post-norm residual block: self-attention, then feed-forward.
        attn_out = self.attn(v=x, k=x, q=x, mask=mask)
        out1 = self.ln1(x + attn_out)
        ff_out = self.ff(out1)
        return self.ln2(out1 + ff_out)
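# Illustrative check (not part of the model): a single block maps (batch, seq, hidden)
# to the same shape, which is what allows the blocks to be stacked in JudgeXL below.
_block_demo = TransformerXLBlock(config)
_x_demo = torch.randn(2, 16, config["hidden_size"])
assert _block_demo(_x_demo).shape == _x_demo.shape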
class JudgeXL(nn.Module):
    def __init__(self, config, tokenizer=None):
        super(JudgeXL, self).__init__()
        self.token_embedding = CustomEmbedding(config['vocab_size'], config['hidden_size'])
        self.pos_encoding = PositionalEncoding(config['hidden_size'], config['max_len'])
        self.transformer_blocks = nn.ModuleList([TransformerXLBlock(config) for _ in range(config['n_layer'])])
        self.ln_f = nn.LayerNorm(config['hidden_size'])
        self.rnn = nn.LSTM(config['hidden_size'], config['rnn_units'], num_layers=2,
                           dropout=config['dropout'], bidirectional=True, batch_first=True)
        # The bidirectional LSTM doubles the feature size, hence rnn_units * 2.
        self.fc = nn.Linear(config['rnn_units'] * 2, config['vocab_size'])
        # generate() needs a tokenizer; pass one here or assign model.tokenizer later.
        self.tokenizer = tokenizer

    def forward(self, x, mask=None):
        x = self.token_embedding(x)
        x = self.pos_encoding(x)
        for block in self.transformer_blocks:
            x = block(x, mask=mask)
        x = self.ln_f(x)
        x, _ = self.rnn(x)
        x = self.fc(x)  # (batch, seq, vocab_size)
        return x

    def generate(self, prompt, max_len=100):
        self.eval()
        device = next(self.parameters()).device
        input_ids = self.tokenizer(prompt, return_tensors='pt').input_ids.to(device)
        generated = input_ids
        with torch.no_grad():
            for _ in range(max_len):
                outputs = self.forward(generated)
                next_token_logits = outputs[:, -1, :]  # logits for the last position only
                next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
                generated = torch.cat((generated, next_token_id), dim=1)
                if next_token_id.item() == self.tokenizer.sep_token_id:
                    break
        generated_text = self.tokenizer.decode(generated[0], skip_special_tokens=True)
        return generated_text
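# Illustrative shape check with a scaled-down copy of the config (assumption: any config
# whose hidden_size divides evenly by n_head works); forward() should return per-token logits.
_demo_config = {**config, "n_layer": 2, "hidden_size": 64, "rnn_units": 64, "n_head": 4}
_demo_model = JudgeXL(_demo_config)
_demo_out = _demo_model(torch.randint(0, _demo_config["vocab_size"], (2, 16)))
assert _demo_out.shape == (2, 16, _demo_config["vocab_size"])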
# Load the last saved model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = JudgeXL(config)
# weights_only=False implies the checkpoint holds the full pickled module, so this load
# replaces the freshly constructed model above rather than filling in its weights.
model = torch.load('C:/AIstuffing/Judge_XL-LLM/xl-llm_weights/judgeXL-LLm_wiki.pth', weights_only=False)
model = model.to(device)
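# Usage sketch for generate(); assumptions: the checkpoint was trained with a BERT-style
# WordPiece tokenizer (vocab_size 30522 matches bert-base-uncased) and the Hugging Face
# transformers package is installed -- swap in the tokenizer actually used during training.
from transformers import AutoTokenizer
model.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
print(model.generate("The history of the supreme court", max_len=50))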