Update generate.py
Browse files- generate.py +30 -222
generate.py
CHANGED
@@ -1,222 +1,30 @@
|
|
1 |
-
import
|
2 |
-
import
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
(cp >= 0x2B740 and cp <= 0x2B81F) or #
|
32 |
-
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
33 |
-
(cp >= 0xF900 and cp <= 0xFAFF) or #
|
34 |
-
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
|
35 |
-
return True
|
36 |
-
|
37 |
-
return False
|
38 |
-
|
39 |
-
|
40 |
-
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
|
41 |
-
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
42 |
-
Args:
|
43 |
-
logits: logits distribution shape (vocabulary size)
|
44 |
-
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
45 |
-
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
46 |
-
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
47 |
-
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
48 |
-
"""
|
49 |
-
assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
|
50 |
-
top_k = min(top_k, logits.size(-1)) # Safety check
|
51 |
-
if top_k > 0:
|
52 |
-
# Remove all tokens with a probability less than the last token of the top-k
|
53 |
-
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
54 |
-
logits[indices_to_remove] = filter_value
|
55 |
-
|
56 |
-
if top_p > 0.0:
|
57 |
-
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
58 |
-
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
59 |
-
|
60 |
-
# Remove tokens with cumulative probability above the threshold
|
61 |
-
sorted_indices_to_remove = cumulative_probs > top_p
|
62 |
-
# Shift the indices to the right to keep also the first token above the threshold
|
63 |
-
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
64 |
-
sorted_indices_to_remove[..., 0] = 0
|
65 |
-
|
66 |
-
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
67 |
-
logits[indices_to_remove] = filter_value
|
68 |
-
return logits
|
69 |
-
|
70 |
-
|
71 |
-
def sample_sequence(model, context, length, n_ctx, tokenizer, temperature=1.0, top_k=30, top_p=0.0, repitition_penalty=1.0,
|
72 |
-
device='cpu'):
|
73 |
-
context = torch.tensor(context, dtype=torch.long, device=device)
|
74 |
-
context = context.unsqueeze(0)
|
75 |
-
generated = context
|
76 |
-
with torch.no_grad():
|
77 |
-
for _ in trange(length):
|
78 |
-
inputs = {'input_ids': generated[0][-(n_ctx - 1):].unsqueeze(0)}
|
79 |
-
outputs = model(
|
80 |
-
**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
81 |
-
next_token_logits = outputs[0][0, -1, :]
|
82 |
-
for id in set(generated):
|
83 |
-
next_token_logits[id] /= repitition_penalty
|
84 |
-
next_token_logits = next_token_logits / temperature
|
85 |
-
next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
|
86 |
-
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
87 |
-
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
88 |
-
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
|
89 |
-
return generated.tolist()[0]
|
90 |
-
|
91 |
-
|
92 |
-
def fast_sample_sequence(model, context, length, temperature=1.0, top_k=30, top_p=0.0, device='cpu'):
|
93 |
-
inputs = torch.LongTensor(context).view(1, -1).to(device)
|
94 |
-
if len(context) > 1:
|
95 |
-
_, past = model(inputs[:, :-1], None)[:2]
|
96 |
-
prev = inputs[:, -1].view(1, -1)
|
97 |
-
else:
|
98 |
-
past = None
|
99 |
-
prev = inputs
|
100 |
-
generate = [] + context
|
101 |
-
with torch.no_grad():
|
102 |
-
for i in trange(length):
|
103 |
-
output = model(prev, past=past)
|
104 |
-
output, past = output[:2]
|
105 |
-
output = output[-1].squeeze(0) / temperature
|
106 |
-
filtered_logits = top_k_top_p_filtering(output, top_k=top_k, top_p=top_p)
|
107 |
-
next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
|
108 |
-
generate.append(next_token.item())
|
109 |
-
prev = next_token.view(1, 1)
|
110 |
-
return generate
|
111 |
-
|
112 |
-
|
113 |
-
# 通过命令行参数--fast_pattern,指定模式
|
114 |
-
def generate(n_ctx, model, context, length, tokenizer, temperature=1, top_k=0, top_p=0.0, repitition_penalty=1.0, device='cpu',
|
115 |
-
is_fast_pattern=False):
|
116 |
-
if is_fast_pattern:
|
117 |
-
return fast_sample_sequence(model, context, length, temperature=temperature, top_k=top_k, top_p=top_p,
|
118 |
-
device=device)
|
119 |
-
else:
|
120 |
-
return sample_sequence(model, context, length, n_ctx, tokenizer=tokenizer, temperature=temperature, top_k=top_k, top_p=top_p,
|
121 |
-
repitition_penalty=repitition_penalty, device=device)
|
122 |
-
|
123 |
-
|
124 |
-
def main():
|
125 |
-
parser = argparse.ArgumentParser()
|
126 |
-
parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='生成设备')
|
127 |
-
parser.add_argument('--length', default=-1, type=int, required=False, help='生成长度')
|
128 |
-
parser.add_argument('--batch_size', default=1, type=int, required=False, help='生成的batch size')
|
129 |
-
parser.add_argument('--nsamples', default=10, type=int, required=False, help='生成几个样本')
|
130 |
-
parser.add_argument('--temperature', default=1, type=float, required=False, help='生成温度')
|
131 |
-
parser.add_argument('--topk', default=8, type=int, required=False, help='最高几选一')
|
132 |
-
parser.add_argument('--topp', default=0, type=float, required=False, help='最高积累概率')
|
133 |
-
parser.add_argument('--model_config', default='./model_config_small.json', type=str, required=False,
|
134 |
-
help='模型参数')
|
135 |
-
parser.add_argument('--tokenizer_path', default='./vocab_small.txt', type=str, required=False, help='词表路径')
|
136 |
-
parser.add_argument('--model_path', default='./', type=str, required=False, help='模型路径')
|
137 |
-
parser.add_argument('--prefix', default='萧炎', type=str, required=False, help='生成文章的开头')
|
138 |
-
parser.add_argument('--no_wordpiece', action='store_true', help='不做word piece切词')
|
139 |
-
parser.add_argument('--segment', action='store_true', help='中文以词为单位')
|
140 |
-
parser.add_argument('--fast_pattern', action='store_true', help='采用更加快的方式生成文本')
|
141 |
-
parser.add_argument('--save_samples', action='store_true', help='保存产生的样本')
|
142 |
-
parser.add_argument('--save_samples_path', default='.', type=str, required=False, help="保存样本的路径")
|
143 |
-
parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False)
|
144 |
-
|
145 |
-
args = parser.parse_args()
|
146 |
-
print('args:\n' + args.__repr__())
|
147 |
-
|
148 |
-
if args.segment:
|
149 |
-
from tokenizations import tokenization_bert_word_level as tokenization_bert
|
150 |
-
else:
|
151 |
-
from tokenizations import tokenization_bert
|
152 |
-
|
153 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = args.device # 此处设置程序使用哪些显卡
|
154 |
-
length = args.length
|
155 |
-
batch_size = args.batch_size
|
156 |
-
nsamples = args.nsamples
|
157 |
-
temperature = args.temperature
|
158 |
-
topk = args.topk
|
159 |
-
topp = args.topp
|
160 |
-
repetition_penalty = args.repetition_penalty
|
161 |
-
|
162 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
163 |
-
|
164 |
-
tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
|
165 |
-
model = GPT2LMHeadModel.from_pretrained(args.model_path)
|
166 |
-
model.to(device)
|
167 |
-
model.eval()
|
168 |
-
|
169 |
-
n_ctx = model.config.n_ctx
|
170 |
-
|
171 |
-
if length == -1:
|
172 |
-
length = model.config.n_ctx
|
173 |
-
if args.save_samples:
|
174 |
-
if not os.path.exists(args.save_samples_path):
|
175 |
-
os.makedirs(args.save_samples_path)
|
176 |
-
samples_file = open(args.save_samples_path + '/samples.txt', 'w', encoding='utf8')
|
177 |
-
while True:
|
178 |
-
raw_text = args.prefix
|
179 |
-
context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text))
|
180 |
-
generated = 0
|
181 |
-
for _ in range(nsamples // batch_size):
|
182 |
-
out = generate(
|
183 |
-
n_ctx=n_ctx,
|
184 |
-
model=model,
|
185 |
-
context=context_tokens,
|
186 |
-
length=length,
|
187 |
-
is_fast_pattern=args.fast_pattern, tokenizer=tokenizer,
|
188 |
-
temperature=temperature, top_k=topk, top_p=topp, repitition_penalty=repetition_penalty, device=device
|
189 |
-
)
|
190 |
-
for i in range(batch_size):
|
191 |
-
generated += 1
|
192 |
-
text = tokenizer.convert_ids_to_tokens(out)
|
193 |
-
for i, item in enumerate(text[:-1]): # 确保英文前后有空格
|
194 |
-
if is_word(item) and is_word(text[i + 1]):
|
195 |
-
text[i] = item + ' '
|
196 |
-
for i, item in enumerate(text):
|
197 |
-
if item == '[MASK]':
|
198 |
-
text[i] = ''
|
199 |
-
elif item == '[CLS]':
|
200 |
-
text[i] = '\n\n'
|
201 |
-
elif item == '[SEP]':
|
202 |
-
text[i] = '\n'
|
203 |
-
info = "=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40 + "\n"
|
204 |
-
print(info)
|
205 |
-
text = ''.join(text).replace('##', '').strip()
|
206 |
-
print(text)
|
207 |
-
if args.save_samples:
|
208 |
-
samples_file.write(info)
|
209 |
-
samples_file.write(text)
|
210 |
-
samples_file.write('\n')
|
211 |
-
samples_file.write('=' * 90)
|
212 |
-
samples_file.write('\n' * 2)
|
213 |
-
print("=" * 80)
|
214 |
-
if generated == nsamples:
|
215 |
-
# close file when finish writing.
|
216 |
-
if args.save_samples:
|
217 |
-
samples_file.close()
|
218 |
-
break
|
219 |
-
|
220 |
-
|
221 |
-
if __name__ == '__main__':
|
222 |
-
main()
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import subprocess
|
3 |
+
|
4 |
+
|
5 |
+
def generate_text(length, prefix, temperature,topk,topp,rep):
|
6 |
+
# 构建命令行参数
|
7 |
+
my_prefix = "--prefix=" + prefix + ","
|
8 |
+
args = ["python", "generate.py", f"--length={int(length)}", f"--nsamples=1", f"--prefix={prefix}", f"--temperature={temperature}",f"--batch_size=1",f"--topk={int(topk)}",f"--topp={topp}",f"--repetition_penalty={rep}","--fast_pattern","--tokenizer_path=./vocab.txt","--model_config=./config.json"]
|
9 |
+
|
10 |
+
# 执行命令并获取输出
|
11 |
+
process = subprocess.Popen(args, stdout=subprocess.PIPE)
|
12 |
+
output, error = process.communicate()
|
13 |
+
output = output.decode("utf-8")
|
14 |
+
|
15 |
+
return output
|
16 |
+
|
17 |
+
input_length = gr.Slider(label="生成文本长度", minimum=10, maximum=500, value=500,step=10)
|
18 |
+
input_prefix = gr.Textbox(label="起始文本")
|
19 |
+
input_temperature = gr.Slider(label="生成温度", minimum=0, maximum=2, value=1, step=0.01)
|
20 |
+
#input_batchsize = gr.Slider(label="生成的batch size", minimum=1, maximum=1, value=1,step=1)
|
21 |
+
input_topk = gr.Slider(label="最高几选一", minimum=1, maximum=48, value=32, step=1)
|
22 |
+
input_topp = gr.Slider(label="最高积累概率", minimum=0, maximum=1, value=0,step=0.01)
|
23 |
+
input_repeat_penality = gr.Slider(label="重复罚值", minimum=0, maximum=15, value=10,step=0.01)
|
24 |
+
|
25 |
+
output_text = gr.Textbox(label="生成的文本")
|
26 |
+
|
27 |
+
title = "GPT2中文文本生成器"
|
28 |
+
description = "cpu推理约1s/字,温度太低基本是无意义字符"
|
29 |
+
|
30 |
+
gr.Interface(fn=generate_text, inputs=[input_length, input_prefix, input_temperature,input_topk,input_topp,input_repeat_penality], outputs=output_text, title=title, description=description).launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|