from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("af1tang/personaGPT")
model = AutoModelForCausalLM.from_pretrained("af1tang/personaGPT")
if torch.cuda.is_available():
    model = model.cuda()

## utility functions ##
flatten = lambda l: [item for sublist in l for item in sublist]

def to_data(x):
    # move a tensor off the GPU and return it as a numpy array
    if torch.cuda.is_available():
        x = x.cpu()
    return x.data.numpy()

def to_var(x):
    # convert input to a tensor and move it to the GPU if available
    if not torch.is_tensor(x):
        x = torch.Tensor(x)
    if torch.cuda.is_available():
        x = x.cuda()
    return x

def display_dialog_history(dialog_hx):
    # even-indexed turns are the user's, odd-indexed turns are the bot's
    for j, line in enumerate(dialog_hx):
        msg = tokenizer.decode(line)
        if j % 2 == 0:
            print(">> User: " + msg)
        else:
            print("Bot: " + msg)
            print()

def generate_next(bot_input_ids, do_sample=True, top_k=10, top_p=.92,
                  max_length=1000, pad_token=tokenizer.eos_token_id):
    full_msg = model.generate(bot_input_ids, do_sample=do_sample,
                              top_k=top_k, top_p=top_p,
                              max_length=max_length, pad_token_id=pad_token)
    # strip off the input prompt, keeping only the newly generated tokens
    msg = to_data(full_msg.detach()[0])[bot_input_ids.shape[-1]:]
    return msg

# get personality facts for conversation
personas = []
for i in range(3):
    response = input(">> Fact %d: " % (i + 1)) + tokenizer.eos_token
    personas.append(response)
personas = tokenizer.encode(''.join(['<|p2|>'] + personas + ['<|sep|>'] + ['<|start|>']))
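
# A minimal interactive chat loop sketched on top of the helpers above: each
# user turn is appended to the dialog history, the persona prefix plus the
# flattened history becomes the model input, and the bot's reply is decoded
# for display. The 8-turn limit is an arbitrary choice for this sketch, not
# a requirement of the model.
dialog_hx = []
for step in range(8):
    # encode the user's message and add it to the running history
    user_inp = tokenizer.encode(input(">> User: ") + tokenizer.eos_token)
    dialog_hx.append(user_inp)
    # condition generation on the persona facts followed by the full dialog
    bot_input_ids = to_var([personas + flatten(dialog_hx)]).long()
    msg = generate_next(bot_input_ids)
    dialog_hx.append(msg)
    print("Bot: {}".format(tokenizer.decode(msg, skip_special_tokens=True)))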