mzltest committed on
Commit
64622dd
1 Parent(s): 4e22193

Update generate.py

Files changed (1)
generate.py +30 -222
generate.py CHANGED
@@ -1,222 +1,30 @@
- import torch
- import torch.nn.functional as F
- import os
- import argparse
- from tqdm import trange
- from transformers import GPT2LMHeadModel
-
-
- def is_word(word):
-     for item in list(word):
-         if item not in 'qwertyuiopasdfghjklzxcvbnm':
-             return False
-     return True
-
-
- def _is_chinese_char(char):
-     """Checks whether CP is the codepoint of a CJK character."""
-     # This defines a "chinese character" as anything in the CJK Unicode block:
-     #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
-     #
-     # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
-     # despite its name. The modern Korean Hangul alphabet is a different block,
-     # as is Japanese Hiragana and Katakana. Those alphabets are used to write
-     # space-separated words, so they are not treated specially and are handled
-     # like all of the other languages.
-     cp = ord(char)
-     if ((cp >= 0x4E00 and cp <= 0x9FFF) or
-             (cp >= 0x3400 and cp <= 0x4DBF) or
-             (cp >= 0x20000 and cp <= 0x2A6DF) or
-             (cp >= 0x2A700 and cp <= 0x2B73F) or
-             (cp >= 0x2B740 and cp <= 0x2B81F) or
-             (cp >= 0x2B820 and cp <= 0x2CEAF) or
-             (cp >= 0xF900 and cp <= 0xFAFF) or
-             (cp >= 0x2F800 and cp <= 0x2FA1F)):
-         return True
-
-     return False
-
-
- def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
-         Args:
-             logits: logits distribution shape (vocabulary size)
-             top_k > 0: keep only top k tokens with highest probability (top-k filtering).
-             top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
-                 Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
-         From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
-     """
-     assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
-     top_k = min(top_k, logits.size(-1))  # Safety check
-     if top_k > 0:
-         # Remove all tokens with a probability less than the last token of the top-k
-         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-         logits[indices_to_remove] = filter_value
-
-     if top_p > 0.0:
-         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-         # Remove tokens with cumulative probability above the threshold
-         sorted_indices_to_remove = cumulative_probs > top_p
-         # Shift the indices to the right to keep also the first token above the threshold
-         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-         sorted_indices_to_remove[..., 0] = 0
-
-         indices_to_remove = sorted_indices[sorted_indices_to_remove]
-         logits[indices_to_remove] = filter_value
-     return logits
-
-
- def sample_sequence(model, context, length, n_ctx, tokenizer, temperature=1.0, top_k=30, top_p=0.0, repitition_penalty=1.0,
-                     device='cpu'):
-     context = torch.tensor(context, dtype=torch.long, device=device)
-     context = context.unsqueeze(0)
-     generated = context
-     with torch.no_grad():
-         for _ in trange(length):
-             inputs = {'input_ids': generated[0][-(n_ctx - 1):].unsqueeze(0)}
-             outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
-             next_token_logits = outputs[0][0, -1, :]
-             for id in set(generated):
-                 next_token_logits[id] /= repitition_penalty
-             next_token_logits = next_token_logits / temperature
-             next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
-             filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
-             next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
-             generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
-     return generated.tolist()[0]
-
-
- def fast_sample_sequence(model, context, length, temperature=1.0, top_k=30, top_p=0.0, device='cpu'):
-     inputs = torch.LongTensor(context).view(1, -1).to(device)
-     if len(context) > 1:
-         _, past = model(inputs[:, :-1], None)[:2]
-         prev = inputs[:, -1].view(1, -1)
-     else:
-         past = None
-         prev = inputs
-     generate = [] + context
-     with torch.no_grad():
-         for i in trange(length):
-             output = model(prev, past=past)
-             output, past = output[:2]
-             output = output[-1].squeeze(0) / temperature
-             filtered_logits = top_k_top_p_filtering(output, top_k=top_k, top_p=top_p)
-             next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
-             generate.append(next_token.item())
-             prev = next_token.view(1, 1)
-     return generate
-
-
- # The --fast_pattern command-line flag selects the sampling mode
- def generate(n_ctx, model, context, length, tokenizer, temperature=1, top_k=0, top_p=0.0, repitition_penalty=1.0, device='cpu',
-              is_fast_pattern=False):
-     if is_fast_pattern:
-         return fast_sample_sequence(model, context, length, temperature=temperature, top_k=top_k, top_p=top_p,
-                                     device=device)
-     else:
-         return sample_sequence(model, context, length, n_ctx, tokenizer=tokenizer, temperature=temperature, top_k=top_k, top_p=top_p,
-                                repitition_penalty=repitition_penalty, device=device)
-
-
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='devices to generate on')
-     parser.add_argument('--length', default=-1, type=int, required=False, help='generation length')
-     parser.add_argument('--batch_size', default=1, type=int, required=False, help='batch size for generation')
-     parser.add_argument('--nsamples', default=10, type=int, required=False, help='number of samples to generate')
-     parser.add_argument('--temperature', default=1, type=float, required=False, help='sampling temperature')
-     parser.add_argument('--topk', default=8, type=int, required=False, help='top-k: sample among the k most likely tokens')
-     parser.add_argument('--topp', default=0, type=float, required=False, help='top-p: nucleus cumulative probability')
-     parser.add_argument('--model_config', default='./model_config_small.json', type=str, required=False,
-                         help='model config')
-     parser.add_argument('--tokenizer_path', default='./vocab_small.txt', type=str, required=False, help='vocabulary path')
-     parser.add_argument('--model_path', default='./', type=str, required=False, help='model path')
-     parser.add_argument('--prefix', default='萧炎', type=str, required=False, help='prefix to start the generated text with')
-     parser.add_argument('--no_wordpiece', action='store_true', help='do not apply word-piece tokenization')
-     parser.add_argument('--segment', action='store_true', help='tokenize Chinese at the word level')
-     parser.add_argument('--fast_pattern', action='store_true', help='use the faster generation path')
-     parser.add_argument('--save_samples', action='store_true', help='save the generated samples')
-     parser.add_argument('--save_samples_path', default='.', type=str, required=False, help='path to save samples to')
-     parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False)
-
-     args = parser.parse_args()
-     print('args:\n' + args.__repr__())
-
-     if args.segment:
-         from tokenizations import tokenization_bert_word_level as tokenization_bert
-     else:
-         from tokenizations import tokenization_bert
-
-     os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # select which GPUs the program may use
-     length = args.length
-     batch_size = args.batch_size
-     nsamples = args.nsamples
-     temperature = args.temperature
-     topk = args.topk
-     topp = args.topp
-     repetition_penalty = args.repetition_penalty
-
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-
-     tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
-     model = GPT2LMHeadModel.from_pretrained(args.model_path)
-     model.to(device)
-     model.eval()
-
-     n_ctx = model.config.n_ctx
-
-     if length == -1:
-         length = model.config.n_ctx
-     if args.save_samples:
-         if not os.path.exists(args.save_samples_path):
-             os.makedirs(args.save_samples_path)
-         samples_file = open(args.save_samples_path + '/samples.txt', 'w', encoding='utf8')
-     while True:
-         raw_text = args.prefix
-         context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text))
-         generated = 0
-         for _ in range(nsamples // batch_size):
-             out = generate(
-                 n_ctx=n_ctx,
-                 model=model,
-                 context=context_tokens,
-                 length=length,
-                 is_fast_pattern=args.fast_pattern, tokenizer=tokenizer,
-                 temperature=temperature, top_k=topk, top_p=topp, repitition_penalty=repetition_penalty, device=device
-             )
-             for i in range(batch_size):
-                 generated += 1
-                 text = tokenizer.convert_ids_to_tokens(out)
-                 for i, item in enumerate(text[:-1]):  # make sure consecutive English words are separated by a space
-                     if is_word(item) and is_word(text[i + 1]):
-                         text[i] = item + ' '
-                 for i, item in enumerate(text):
-                     if item == '[MASK]':
-                         text[i] = ''
-                     elif item == '[CLS]':
-                         text[i] = '\n\n'
-                     elif item == '[SEP]':
-                         text[i] = '\n'
-                 info = "=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40 + "\n"
-                 print(info)
-                 text = ''.join(text).replace('##', '').strip()
-                 print(text)
-                 if args.save_samples:
-                     samples_file.write(info)
-                     samples_file.write(text)
-                     samples_file.write('\n')
-                     samples_file.write('=' * 90)
-                     samples_file.write('\n' * 2)
-         print("=" * 80)
-         if generated == nsamples:
-             # close file when finished writing
-             if args.save_samples:
-                 samples_file.close()
-             break
-
-
- if __name__ == '__main__':
-     main()
 
+ import gradio as gr
+ import subprocess
+
+
+ def generate_text(length, prefix, temperature, topk, topp, rep):
+     # Build the command line for the sampling script
+     args = ["python", "generate.py",
+             f"--length={int(length)}", "--nsamples=1", f"--prefix={prefix}",
+             f"--temperature={temperature}", "--batch_size=1",
+             f"--topk={int(topk)}", f"--topp={topp}", f"--repetition_penalty={rep}",
+             "--fast_pattern", "--tokenizer_path=./vocab.txt", "--model_config=./config.json"]
+
+     # Run the command and capture its output
+     process = subprocess.Popen(args, stdout=subprocess.PIPE)
+     output, _ = process.communicate()
+     output = output.decode("utf-8")
+
+     return output
+
+
+ input_length = gr.Slider(label="Generated text length", minimum=10, maximum=500, value=500, step=10)
+ input_prefix = gr.Textbox(label="Prefix text")
+ input_temperature = gr.Slider(label="Sampling temperature", minimum=0, maximum=2, value=1, step=0.01)
+ # input_batchsize = gr.Slider(label="Generation batch size", minimum=1, maximum=1, value=1, step=1)
+ input_topk = gr.Slider(label="Top-k (sample among the k most likely tokens)", minimum=1, maximum=48, value=32, step=1)
+ input_topp = gr.Slider(label="Top-p (nucleus cumulative probability)", minimum=0, maximum=1, value=0, step=0.01)
+ input_repetition_penalty = gr.Slider(label="Repetition penalty", minimum=0, maximum=15, value=10, step=0.01)
+
+ output_text = gr.Textbox(label="Generated text")
+
+ title = "GPT2 Chinese Text Generator"
+ description = "CPU inference runs at roughly 1 s per character; temperatures set too low mostly produce meaningless characters"
+
+ gr.Interface(fn=generate_text,
+              inputs=[input_length, input_prefix, input_temperature, input_topk, input_topp, input_repetition_penalty],
+              outputs=output_text, title=title, description=description).launch()
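
For reference, a minimal standalone sketch of the call generate_text() makes, useful for testing the sampling script without the UI. The flag values below are illustrative (they mirror the slider defaults above), and it assumes generate.py, vocab.txt, and config.json sit in the working directory, as the hard-coded paths in generate_text() do:

import subprocess

# Hypothetical direct invocation mirroring generate_text() above;
# the flag values are example defaults, not fixed by the app.
args = [
    "python", "generate.py",
    "--length=500", "--nsamples=1", "--batch_size=1",
    "--prefix=萧炎", "--temperature=1.0",
    "--topk=32", "--topp=0", "--repetition_penalty=10",
    "--fast_pattern",
    "--tokenizer_path=./vocab.txt", "--model_config=./config.json",
]
# subprocess.run with capture_output=True also collects stderr, which the
# Popen/communicate pattern above drops; text=True yields str instead of bytes.
result = subprocess.run(args, capture_output=True, text=True)
print(result.stdout)

Using subprocess.run here is a stylistic choice: the Popen form in the diff behaves the same for stdout but silently discards anything the child process writes to stderr.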