siroleo0 commited on
Commit
6798aed
·
1 Parent(s): fe83052

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
2
+ import torch
3
+ tokenizer = AutoTokenizer.from_pretrained("af1tang/personaGPT")
4
+ model = AutoModelForCausalLM.from_pretrained("af1tang/personaGPT")
5
+ if torch.cuda.is_available():
6
+ model = model.cuda()
7
+ ## utility functions ##
8
+ flatten = lambda l: [item for sublist in l for item in sublist]
9
+
10
+ def to_data(x):
11
+ if torch.cuda.is_available():
12
+ x = x.cpu()
13
+ return x.data.numpy()
14
+
15
+ def to_var(x):
16
+ if not torch.is_tensor(x):
17
+ x = torch.Tensor(x)
18
+ if torch.cuda.is_available():
19
+ x = x.cuda()
20
+ return x
21
+
22
+ def display_dialog_history(dialog_hx):
23
+ for j, line in enumerate(dialog_hx):
24
+ msg = tokenizer.decode(line)
25
+ if j %2 == 0:
26
+ print(">> User: "+ msg)
27
+ else:
28
+ print("Bot: "+msg)
29
+ print()
30
+
31
+ def generate_next(bot_input_ids, do_sample=True, top_k=10, top_p=.92,
32
+ max_length=1000, pad_token=tokenizer.eos_token_id):
33
+ full_msg = model.generate(bot_input_ids, do_sample=True,
34
+ top_k=top_k, top_p=top_p,
35
+ max_length=max_length, pad_token_id=tokenizer.eos_token_id)
36
+ msg = to_data(full_msg.detach()[0])[bot_input_ids.shape[-1]:]
37
+ return msg
38
+
39
+
40
+
41
+
42
+
43
+ # get personality facts for conversation
44
+ personas = []
45
+ for i in range(3):
46
+ response = input(">> Fact %d: "%(i+1))+ tokenizer.eos_token
47
+ personas.append(response)
48
+ personas = tokenizer.encode(''.join(['<|p2|>'] + personas + ['<|sep|>'] + ['<|start|>']))
49
+
50
+
51
+