import numpy as np
import torch
import torch.nn as nn

config = {
    "learning_rate": 1e-4,
    "batch_size": 32,
    "vocab_size": 30522,
    "max_len": 256,
    "hidden_size": 768,
    "dropout": 0.1,
    "n_layer": 12,
    "n_head": 12,
    "ff_expansion_factor": 4,
    "rnn_units": 768,
    "num_labels": 5
}

class MyClass:
    def __init__(self, value):
        self.value = value

def custom_initializer(shape):
    # Sample initial weights from N(0, 0.02), a common choice for transformer embeddings.
    return torch.normal(mean=0.0, std=0.02, size=shape)

class CustomEmbedding(nn.Module):
    """Token embedding whose weight matrix is initialized by custom_initializer."""

    def __init__(self, vocab_size, hidden_size):
        super(CustomEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size,
                                      _weight=custom_initializer((vocab_size, hidden_size)))

    def forward(self, inputs):
        return self.embedding(inputs)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding, stored batch-first as (1, max_len, n_embd)."""

    def __init__(self, n_embd, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.n_embd = n_embd
        self.max_len = max_len

        pe = torch.zeros(max_len, n_embd)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, n_embd, 2).float() * -(np.log(10000.0) / n_embd))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, n_embd) so it broadcasts over the batch dimension
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, n_embd); add the encodings for the first seq_len positions.
        return x + self.pe[:, :x.size(1), :]

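# Illustrative shape check for the embedding + positional-encoding pipeline; this helper is not
# part of the original script, and the batch/sequence sizes are arbitrary.
def _check_embedding_shapes(cfg=config, batch_size=2, seq_len=16):
    emb = CustomEmbedding(cfg['vocab_size'], cfg['hidden_size'])
    pos = PositionalEncoding(cfg['hidden_size'], cfg['max_len'])
    tokens = torch.randint(0, cfg['vocab_size'], (batch_size, seq_len))
    out = pos(emb(tokens))
    assert out.shape == (batch_size, seq_len, cfg['hidden_size'])
    return out.shape
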
class MultiheadAttention(nn.Module):
    def __init__(self, config):
        super(MultiheadAttention, self).__init__()
        # batch_first=True so inputs are (batch, seq_len, hidden_size), matching the rest of the model.
        self.attention = nn.MultiheadAttention(config['hidden_size'], config['n_head'],
                                               dropout=config['dropout'], batch_first=True)

    def forward(self, v, k, q, mask=None):
        attn_output, attn_output_weights = self.attention(q, k, v, attn_mask=mask)
        return attn_output

class FeedForward(nn.Module):
    def __init__(self, config):
        super(FeedForward, self).__init__()
        self.dense1 = nn.Linear(config['hidden_size'], config['hidden_size'] * config['ff_expansion_factor'])
        self.dense2 = nn.Linear(config['hidden_size'] * config['ff_expansion_factor'], config['hidden_size'])
        self.dropout = nn.Dropout(config['dropout'])

    def forward(self, x):
        x = torch.nn.functional.gelu(self.dense1(x))
        x = self.dropout(x)
        return self.dense2(x)

class TransformerXLBlock(nn.Module):
    def __init__(self, config):
        super(TransformerXLBlock, self).__init__()
        self.attn = MultiheadAttention(config)
        self.ff = FeedForward(config)
        self.ln1 = nn.LayerNorm(config['hidden_size'])
        self.ln2 = nn.LayerNorm(config['hidden_size'])

    def forward(self, x, mask=None):
        attn_out = self.attn(v=x, k=x, q=x, mask=mask)
        out1 = self.ln1(x + attn_out)
        ff_out = self.ff(out1)
        return self.ln2(out1 + ff_out)

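# Illustrative check (not in the original script): a TransformerXLBlock should preserve the
# (batch, seq_len, hidden_size) shape, since attention, feed-forward, and layer norm are all
# shape-preserving. Sizes below are arbitrary.
def _check_block_shape(cfg=config, batch_size=2, seq_len=16):
    block = TransformerXLBlock(cfg)
    x = torch.randn(batch_size, seq_len, cfg['hidden_size'])
    y = block(x)
    assert y.shape == x.shape
    return y.shape
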
class JudgeXL(nn.Module):
    def __init__(self, config):
        super(JudgeXL, self).__init__()
        self.token_embedding = CustomEmbedding(config['vocab_size'], config['hidden_size'])
        self.pos_encoding = PositionalEncoding(config['hidden_size'], config['max_len'])
        self.transformer_blocks = nn.ModuleList([TransformerXLBlock(config) for _ in range(config['n_layer'])])
        self.ln_f = nn.LayerNorm(config['hidden_size'])
        self.rnn = nn.LSTM(config['hidden_size'], config['rnn_units'], num_layers=2,
                           dropout=config['dropout'], bidirectional=True, batch_first=True)
        # The bidirectional LSTM doubles the feature dimension, hence rnn_units * 2.
        self.fc = nn.Linear(config['rnn_units'] * 2, config['vocab_size'])

    def forward(self, x, mask=None):
        x = self.token_embedding(x)
        x = self.pos_encoding(x)
        for block in self.transformer_blocks:
            x = block(x, mask=mask)
        x = self.ln_f(x)
        x, _ = self.rnn(x)
        x = self.fc(x)
        return x  # (batch, seq_len, vocab_size) logits

    def generate(self, prompt, max_len=100):
        # Greedy decoding. Assumes a Hugging Face-style tokenizer has been attached as
        # self.tokenizer (it is not created in __init__) and that `device` is defined.
        self.eval()
        input_ids = self.tokenizer(prompt, return_tensors='pt').input_ids.to(device)
        generated = input_ids
        with torch.no_grad():
            for _ in range(max_len):
                outputs = self.forward(generated)
                next_token_logits = outputs[:, -1, :]  # logits for the last position only
                next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                generated = torch.cat((generated, next_token_id), dim=1)
                if next_token_id.item() == self.tokenizer.sep_token_id:
                    break
        generated_text = self.tokenizer.decode(generated[0], skip_special_tokens=True)
        return generated_text

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = JudgeXL(config)
# The checkpoint appears to have been saved as a full pickled model (hence weights_only=False),
# so torch.load returns a JudgeXL instance that replaces the freshly constructed one above.
model = torch.load('C:/AIstuffing/Judge_XL-LLM/xl-llm_weights/judgeXL-LLm_wiki.pth', weights_only=False)
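
# Usage sketch (not part of the original script): generate() expects a tokenizer attached as
# model.tokenizer. The vocab_size of 30522 suggests a BERT-style vocabulary, so a
# bert-base-uncased tokenizer from the `transformers` package is assumed here; swap in whatever
# tokenizer the checkpoint was actually trained with. The prompt is arbitrary.
from transformers import AutoTokenizer

model.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = model.to(device)
print(model.generate("The judge ruled that", max_len=50))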