# Sanity check: run the same prompts through a metaseq OPT checkpoint and
# through the Hugging Face `transformers` port of the model, then compare
# their greedy next-token predictions and logits.

import os

import torch

from metaseq import checkpoint_utils
from transformers import GPT2Tokenizer, OPTForCausalLM

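# Machine-specific paths: the local metaseq checkpoint directory and the
# Hugging Face port to compare against.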
path = "./model"
hf_path = "/home/patrick/facebook/opt-30b"

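# Rebuild the GPT-2 tokenizer from the raw vocab/merges files shipped with
# the metaseq checkpoint and save it in `transformers` format alongside it.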
vocab_file = os.path.join(path, "gpt2-vocab.json")
merges_file = os.path.join(path, "gpt2-merges.txt")

tokenizer = GPT2Tokenizer(vocab_file, merges_file)
tokenizer.save_pretrained(path)

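# Load the metaseq checkpoint; `checkpoint[0][0]` below picks the first
# (and only) model from the loaded ensemble.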
checkpoint = checkpoint_utils.load_model_ensemble_and_task(
    [os.path.join(path, "restored.pt")],
    arg_overrides={
        "vocab_filename": vocab_file,
        "merges_filename": merges_file,
    }
)

model = checkpoint[0][0].eval()

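# Load the Hugging Face port of the same model for comparison.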
hf_model = OPTForCausalLM.from_pretrained(hf_path)


def single_batch_forward_logits(prompts):
    """Return the metaseq model's logits for a single prompt."""
    input_ids = tokenizer(prompts, return_tensors="pt").input_ids
    # metaseq expects a leading BOS token (id 0).
    input_ids = torch.cat([torch.tensor([[0]]), input_ids], dim=-1)
    with torch.no_grad():
        logits = model(input_ids)[0]
    return logits


def forward_hf(prompts):
    """Return the Hugging Face model's logits for a single prompt."""
    input_ids = tokenizer(prompts, return_tensors="pt").input_ids
    # Prepend the same BOS token (id 0) so both models see identical inputs.
    input_ids = torch.cat([torch.tensor([[0]]), input_ids], dim=-1)
    with torch.no_grad():
        logits = hf_model(input_ids)[0]
    return logits


prompts = [
    "Today is a beautiful day and I want to",
    "In the city of",
    "Paris is the capital of France and",
    "Computers and mobile phones have taken",
]

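# For each prompt, print the greedy next-word prediction of both models and
# check that their full logit tensors agree within tolerance.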
print("Next word generation") |
|
for prompt in prompts: |
|
print("-------------") |
|
print(f"Prompt: {prompt}...\n") |
|
logits_fsq = single_batch_forward_logits(prompt) |
|
pred_next_token = torch.argmax(logits_fsq[0, -1], -1) |
|
next_token = tokenizer.convert_ids_to_tokens([pred_next_token]) |
|
next_token = next_token[0].replace("Ġ", "") |
|
print(f"Next word: {next_token}") |
|
print("-------------") |
|
logits = forward_hf(prompt) |
|
pred_next_token = torch.argmax(logits[0, -1], -1) |
|
next_token = tokenizer.convert_ids_to_tokens([pred_next_token]) |
|
next_token = next_token[0].replace("Ġ", "") |
|
print(f"Next word: {next_token}") |
|
print("-------------") |
|
|
|
|
|
print("Is equal:", torch.allclose(logits_fsq.cpu(), logits.cpu(), atol=1e-3)) |
|
|