|
import os |
|
from contextlib import nullcontext |
|
import torch |
|
import tiktoken |
|
from model import GPTConfig, GPT |
|
|
|
|
|
# Sampling configuration, read as module globals by infer() at call time.
init_from: str = "resume"      # checkpoint source label; infer() does not read it
out_dir: str = "out-stinfo"    # output directory name; infer() does not read it
start: str = "\n"              # default prompt text fed to the tokenizer
num_samples: int = 1           # requested number of samples
max_new_tokens: int = 100      # passed positionally to model.generate
temperature: float = 0.6       # passed to model.generate(temperature=...)
top_k: int = 200               # passed to model.generate(top_k=...)
seed: int = 1337               # seeds torch CPU and CUDA RNGs before sampling
device: str = "cpu"            # torch device; any value containing 'cuda' enables autocast
dtype: str = "float16"         # autocast dtype name ('float32' | 'bfloat16' | 'float16'); unused on CPU
|
|
|
def _autocast_context():
    """Return the context manager that wraps generation.

    On CPU this is a no-op ``nullcontext``; when ``device`` contains 'cuda' it
    is ``torch.amp.autocast`` using the dtype named by the module-level
    ``dtype`` setting.
    """
    device_type = 'cuda' if 'cuda' in device else 'cpu'
    # Looked up unconditionally so an unknown dtype name fails fast (KeyError)
    # even when running on CPU.
    ptdtype = {
        'float32': torch.float32,
        'bfloat16': torch.bfloat16,
        'float16': torch.float16,
    }[dtype]
    if device_type == 'cpu':
        return nullcontext()
    return torch.amp.autocast(device_type=device_type, dtype=ptdtype)


def _load_model(ckpt_path):
    """Rebuild the GPT stored in ``ckpt_path`` and move it to ``device`` in eval mode.

    The checkpoint is expected to be a dict with 'model_args' (keyword
    arguments for ``GPTConfig``) and 'model' (a state dict).
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=device)
    model = GPT(GPTConfig(**checkpoint['model_args']))
    state_dict = checkpoint['model']
    # Strip the '_orig_mod.' key prefix (likely added when the model was saved
    # from a torch.compile()-wrapped module) so keys match a plain GPT.
    unwanted_prefix = '_orig_mod.'
    for key in list(state_dict):
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    return model


def infer(prompt=None, ckpt_path='ckpt.pt'):
    """Generate one text sample from the GPT checkpoint and return it decoded.

    Reads the module-level settings ``seed``, ``device``, ``dtype``,
    ``start``, ``max_new_tokens``, ``temperature`` and ``top_k``. Exactly one
    sequence is generated because exactly one string is returned, so
    ``num_samples`` is not consulted.

    Args:
        prompt: Text to condition on. ``None`` (the default) uses the
            module-level ``start`` string.
        ckpt_path: Checkpoint file to load; defaults to ``'ckpt.pt'``.

    Returns:
        The GPT-2 (tiktoken) decoding of the first row returned by
        ``model.generate``. Presumably this is the prompt followed by up to
        ``max_new_tokens`` sampled tokens, as in nanoGPT's ``GPT.generate`` —
        confirm against model.py.

    Raises:
        ValueError: if the prompt encodes to zero tokens (the model would
            otherwise receive an empty sequence).
    """
    # Seed before the model is built: GPT(...) may draw from the global RNG
    # while initialising weights, so seeding later (or caching the model
    # across calls) would change the sampled text.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # TF32 flags only affect CUDA matmul / cuDNN execution.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    ctx = _autocast_context()

    model = _load_model(ckpt_path)

    enc = tiktoken.get_encoding("gpt2")
    text = start if prompt is None else prompt
    start_ids = enc.encode(text, allowed_special={"<|endoftext|>"})
    if not start_ids:
        raise ValueError("prompt must encode to at least one token")
    # Add a leading batch dimension: shape (1, len(start_ids)).
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

    with torch.no_grad(), ctx:
        y = model.generate(x, max_new_tokens,
                           temperature=temperature, top_k=top_k)
    return enc.decode(y[0].tolist())