File size: 3,693 Bytes
07f3d5b
 
 
 
 
 
 
 
 
 
 
 
 
 
5f01803
07f3d5b
 
 
 
5f01803
07f3d5b
 
 
 
 
 
 
 
 
 
 
 
 
5f01803
07f3d5b
 
 
8e90038
21443c3
 
ba85e88
77389b9
8e90038
 
77389b9
8e90038
07f3d5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a90639
4bb8e69
 
90152c0
 
 
 
4bb8e69
 
4a90639
 
6ee4f1f
 
4bb8e69
07f3d5b
4bb8e69
 
07f3d5b
4bb8e69
 
 
 
07f3d5b
 
 
4bb8e69
07f3d5b
4bb8e69
 
 
 
07f3d5b
d0186a4
 
07f3d5b
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
import requests

import json
import os

# API credentials come from the process environment; each name is bound to
# None when the corresponding variable is not set.
APIKEY = os.getenv("APIKEY")
APISECRET = os.getenv("APISECRET")

def predict(text, seed, out_seq_length, min_gen_length, sampling_strategy,
    num_beams, length_penalty, no_repeat_ngram_size,
    temperature, topk, topp):
    """Request a completion for *text* from the remote endpoint and return it.

    The generation settings are packed into a JSON body together with the
    module-level ``APIKEY`` / ``APISECRET`` credentials and POSTed to the
    completion URL. The text is read from ``result.output.raw`` in the JSON
    reply; if that field is a list, its first element is used. Every
    ``'[</s>]'`` marker is stripped from the returned string.

    Args:
        text: Prompt sent as ``prompt``.
        seed: Accepted so the signature matches the UI inputs; it is NOT
            included in the request payload.
        out_seq_length: Sent as ``max_tokens``.
        min_gen_length: Sent as ``min_gen_length``.
        sampling_strategy: Sent as ``sampling_strategy``.
        num_beams: Sent as ``num_beams``.
        length_penalty: Sent as ``length_penalty``.
        no_repeat_ngram_size: Accepted so the signature matches the UI
            inputs; it is NOT included in the request payload.
        temperature: Sent as ``temperature``.
        topk: Sent as ``top_k``.
        topp: Sent as ``top_p``.

    Returns:
        The generated text with ``'[</s>]'`` removed, or ``''`` when the
        service returns an empty list of outputs.

    Raises:
        requests.exceptions.RequestException: On connection failure, timeout,
            or a non-2xx HTTP status.
        RuntimeError: If the reply does not contain ``result.output.raw``.
    """
    url = 'https://wudao.aminer.cn/os/api/api/v2/completions_130B'

    payload = json.dumps({
        "apikey": APIKEY,
        "apisecret": APISECRET,
        "language": "zh-CN",
        "prompt": text,
        "length_penalty": length_penalty,
        "temperature": temperature,
        "top_k": topk,
        "top_p": topp,
        "min_gen_length": min_gen_length,
        "sampling_strategy": sampling_strategy,
        "num_beams": num_beams,
        "max_tokens": out_seq_length,
    })

    headers = {
        'Content-Type': 'application/json'
    }

    # (connect, read) timeouts: without them a stalled server would block the
    # calling worker forever.
    response = requests.post(url, headers=headers, data=payload, timeout=(10, 300))
    # Fail with the HTTP status instead of a confusing parse/KeyError later.
    response.raise_for_status()

    # Parse the body once and reuse it for both logging and extraction.
    body = response.json()
    print(body)

    try:
        answer = body['result']['output']['raw']
    except (KeyError, TypeError) as err:
        # The full reply has already been printed above for debugging.
        raise RuntimeError(
            "Completion API reply is missing result.output.raw"
        ) from err

    if isinstance(answer, list):
        # Guard against an empty list rather than raising IndexError.
        answer = answer[0] if answer else ''

    return answer.replace('[</s>]', '')


# Script entry point: build the Gradio Blocks UI, wire the buttons, and launch.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        # Page title and subtitle.
        gr.Markdown(
            """
            # GLM-130B
            An Open Bilingual Pre-Trained Model
            """)

        # Top row: prompt box with its two buttons in a column, then the
        # output box as a sibling in the same row.
        with gr.Row():
            with gr.Column():
                model_input = gr.Textbox(lines=7, placeholder='Input something in English or Chinese', label='Input')
                with gr.Row():
                    gen = gr.Button("Generate")
                    clr = gr.Button("Clear")
                   
            outputs = gr.Textbox(lines=7, label='Output')
                
        gr.Markdown(
            """
            Generation Parameter
            """)
        # General generation settings, split over two columns.
        with gr.Row():
            with gr.Column():
                seed = gr.Slider(maximum=100000, value=1234, step=1, label='Seed')
                out_seq_length = gr.Slider(maximum=256, value=128, minimum=32, step=1, label='Output Sequence Length')
            with gr.Column():
                min_gen_length = gr.Slider(maximum=64, value=0, step=1, label='Min Generate Length')
                sampling_strategy = gr.Radio(choices=['BeamSearchStrategy', 'BaseStrategy'], value='BeamSearchStrategy', label='Search Strategy')

        # Strategy-specific settings: beam search on the left, base search on
        # the right.
        with gr.Row():
            with gr.Column():
                # beam search
                gr.Markdown(
                    """
                    Beam Search Parameter
                    """)
                num_beams = gr.Slider(maximum=4, value=1, minimum=1, step=1, label='Number of Beams')
                length_penalty = gr.Slider(maximum=1, value=0.8, minimum=0, label='Length Penalty')
                no_repeat_ngram_size = gr.Slider(maximum=5, value=3, minimum=1, step=1, label='No Repeat Ngram Size')
            with gr.Column():
                # base search
                gr.Markdown(
                    """
                    Base Search Parameter
                    """)
                temperature = gr.Slider(maximum=1, value=1, minimum=0, label='Temperature')
                topk = gr.Slider(maximum=8, value=1, minimum=0, step=1, label='Top K')
                topp = gr.Slider(maximum=1, value=0, minimum=0, label='Top P')
            
        # The list order must match predict()'s positional parameter order.
        inputs = [model_input, seed, out_seq_length, min_gen_length, sampling_strategy, num_beams, length_penalty, no_repeat_ngram_size, temperature, topk, topp]
        # "Generate" calls predict with the components above and writes its
        # return value into the output textbox.
        gen.click(fn=predict, inputs=inputs, outputs=outputs)
        # "Clear" resets the prompt textbox to an empty string. The button is
        # passed as the sole input, so the lambda takes one argument, which it
        # ignores.
        clr.click(fn=lambda value: gr.update(value=""), inputs=clr, outputs=model_input)

    # Start the app once the layout and event handlers are defined.
    demo.launch()