tingxinli commited on
Commit
3aaecca
·
1 Parent(s): 869130c

Upload hideAndSeek.py

Browse files
Files changed (1) hide show
  1. hideAndSeek.py +50 -0
hideAndSeek.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import transformers
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
4
+ import openai
5
+ from openai import OpenAI
6
+
7
+ def hide_encrypt(original_input, hide_model, tokenizer):
8
+ hide_template = """<s>Paraphrase the text:%s\n\n"""
9
+ input_text = hide_template % original_input
10
+ inputs = tokenizer(input_text, return_tensors='pt').to(hide_model.device)
11
+ pred = hide_model.generate(
12
+ **inputs,
13
+ generation_config=GenerationConfig(
14
+ max_new_tokens = int(len(inputs['input_ids'][0]) * 1.3),
15
+ do_sample=False,
16
+ num_beams=3,
17
+ repetition_penalty=5.0,
18
+ ),
19
+ )
20
+ pred = pred.cpu()[0][len(inputs['input_ids'][0]):]
21
+ hide_input = tokenizer.decode(pred, skip_special_tokens=True)
22
+ return hide_input
23
+
24
+ def seek_decrypt(hide_input, hide_output, original_input, seek_model, tokenizer):
25
+ seek_template = """Convert the text:\n%s\n\n%s\n\nConvert the text:\n%s\n\n"""
26
+ input_text = seek_template % (hide_input, hide_output, original_input)
27
+ inputs = tokenizer(input_text, return_tensors='pt').to(seek_model.device)
28
+ pred = seek_model.generate(
29
+ **inputs,
30
+ generation_config=GenerationConfig(
31
+ max_new_tokens = int(len(inputs['input_ids'][0]) * 1.3),
32
+ do_sample=False,
33
+ num_beams=3,
34
+ ),
35
+ )
36
+ pred = pred.cpu()[0][len(inputs['input_ids'][0]):]
37
+ original_output = tokenizer.decode(pred, skip_special_tokens=True)
38
+ return original_output
39
+
40
+ def get_gpt_output(prompt, api_key=None):
41
+ if not api_key:
42
+ raise ValueError('an open api key is needed for this function')
43
+ client = OpenAI(api_key=api_key)
44
+ completion = client.chat.completions.create(
45
+ model="gpt-3.5-turbo",
46
+ messages=[
47
+ {"role": "user", "content": prompt}
48
+ ]
49
+ )
50
+ return completion.choices[0].message.content