File size: 1,623 Bytes
112f14c
ff094ba
 
112f14c
5c6d076
ff094ba
112f14c
c8f55c6
ff094ba
 
 
 
 
 
112f14c
 
449f209
9cf92f9
5c6d076
 
 
 
 
 
9cf92f9
112f14c
 
449f209
112f14c
5c6d076
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import subprocess
import sys

import gradio as gr


def generate_text(length, prefix, temperature, batchsize, topk, topp, rep):
    """Run ``generate.py`` in a subprocess and return what it prints.

    Each UI value is forwarded to ``generate.py`` as a ``--flag=value``
    command-line argument. ``generate.py`` is resolved relative to the
    current working directory. Exactly one sample is requested
    (``--nsamples=1``).

    Args:
        length: Generation length; truncated to ``int``.
        prefix: Starting text, passed verbatim as ``--prefix``.
        temperature: Sampling temperature, passed through unchanged.
        batchsize: Batch size; truncated to ``int``.
        topk: Top-k cutoff; truncated to ``int``.
        topp: Top-p (cumulative probability) cutoff, passed through unchanged.
        rep: Repetition penalty, passed through unchanged.

    Returns:
        The script's stdout decoded as UTF-8 (invalid bytes replaced).
        If the script exits with a non-zero status, a line with the exit
        status and the script's stderr are appended. This makes the
        failure visible in the UI instead of showing an empty/partial box.
    """
    # The argument vector is a list (no shell), so arbitrary user text in
    # ``prefix`` reaches the script verbatim and is never shell-interpreted.
    args = [
        # Run the child with the same interpreter as this app, so it sees
        # the same installed packages rather than whatever "python" is on
        # PATH.
        sys.executable or "python",
        "generate.py",
        f"--length={int(length)}",
        "--nsamples=1",
        f"--prefix={prefix}",
        f"--temperature={temperature}",
        f"--batch_size={int(batchsize)}",
        f"--topk={int(topk)}",
        f"--topp={topp}",
        f"--repetition_penalty={rep}",
    ]

    # run() waits for and reaps the child; stderr is captured so it can be
    # reported when the script fails.
    result = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    # errors="replace": invalid UTF-8 in the child's output should not turn
    # the whole request into a UnicodeDecodeError.
    output = result.stdout.decode("utf-8", errors="replace")
    error = result.stderr.decode("utf-8", errors="replace")

    if error:
        # Echo the child's stderr to the server console so its warnings and
        # diagnostics remain visible there.
        sys.stderr.write(error)

    if result.returncode != 0:
        output += f"\n[generate.py exited with status {result.returncode}]\n{error}"

    return output

# Gradio UI definition. Each input component maps positionally onto a
# parameter of generate_text().
#
# Gradio 3+ components take their initial value via ``value=``. ``default=``
# is not a gr.Slider parameter there, so the intended starting values below
# were not applied. For example, temperature would start at 0, which the
# description warns produces meaningless characters.
input_length = gr.Slider(label="生成文本长度", minimum=10, maximum=500, value=100, step=10)
input_prefix = gr.Textbox(label="起始文本")
# NOTE(review): temperature 0 is still selectable; if generate.py divides the
# logits by the temperature this would fail — confirm and consider a small
# positive minimum.
input_temperature = gr.Slider(label="生成温度", minimum=0, maximum=2, value=1, step=0.01)
input_batchsize = gr.Slider(label="生成的batch size", minimum=1, maximum=8, value=4, step=1)
input_topk = gr.Slider(label="最高几选一", minimum=1, maximum=20, value=8, step=1)
# Top-p is a cumulative probability, so the range is capped at 1.
# 0 presumably disables top-p filtering in generate.py — confirm.
input_topp = gr.Slider(label="最高积累概率", minimum=0, maximum=1, value=0, step=0.01)
input_repeat_penality = gr.Slider(label="重复罚值", minimum=0, maximum=2, value=1, step=0.01)

output_text = gr.Textbox(label="生成的文本")

title = "GPT2中文文本生成器"
description = "cpu推理约1s/字,温度太低基本是无意义字符"

demo = gr.Interface(
    fn=generate_text,
    # Order must match generate_text(length, prefix, temperature, batchsize,
    # topk, topp, rep).
    inputs=[
        input_length,
        input_prefix,
        input_temperature,
        input_batchsize,
        input_topk,
        input_topp,
        input_repeat_penality,
    ],
    outputs=output_text,
    title=title,
    description=description,
)
# Launched at import time, as before (the app is started by running this file).
demo.launch()