huolongguo10 committed on
Commit 6bea42c · 1 Parent(s): d69ac3c

Create app.py

Files changed (1)
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
+ # Adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/web_demo.py
+
+ import argparse
+ from pathlib import Path
+
+ import chatglm_cpp
+ import gradio as gr
+
+ DEFAULT_MODEL_PATH = Path(__file__).resolve().parent.parent / "chatglm-ggml.bin"
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-m", "--model", default=DEFAULT_MODEL_PATH, type=Path, help="model path")
+ parser.add_argument("--mode", default="chat", type=str, choices=["chat", "generate"], help="inference mode")
+ parser.add_argument("-l", "--max_length", default=2048, type=int, help="max total length including prompt and output")
+ parser.add_argument("-c", "--max_context_length", default=512, type=int, help="max context length")
+ parser.add_argument("--top_k", default=0, type=int, help="top-k sampling")
+ parser.add_argument("--top_p", default=0.7, type=float, help="top-p sampling")
+ parser.add_argument("--temp", default=0.95, type=float, help="temperature")
+ parser.add_argument("--repeat_penalty", default=1.0, type=float, help="penalize repeat sequence of tokens")
+ parser.add_argument("-t", "--threads", default=0, type=int, help="number of threads for inference")
+ parser.add_argument("--plain", action="store_true", help="display in plain text without markdown support")
+ args = parser.parse_args()
+
+ pipeline = chatglm_cpp.Pipeline(args.model)
+
+
+ def postprocess(text):
+     if args.plain:
+         return f"<pre>{text}</pre>"
+     return text
+
+
+ def predict(input, chatbot, max_length, top_p, temperature, history):
+     chatbot.append((postprocess(input), ""))
+     response = ""
+     history.append(input)
+
+     generation_kwargs = dict(
+         max_length=max_length,
+         max_context_length=args.max_context_length,
+         do_sample=temperature > 0,
+         top_k=args.top_k,
+         top_p=top_p,
+         temperature=temperature,
+         repetition_penalty=args.repeat_penalty,
+         num_threads=args.threads,
+         stream=True,
+     )
+     generator = (
+         pipeline.chat(history, **generation_kwargs)
+         if args.mode == "chat"
+         else pipeline.generate(input, **generation_kwargs)
+     )
+     for response_piece in generator:
+         response += response_piece
+         chatbot[-1] = (chatbot[-1][0], postprocess(response))
+
+         yield chatbot, history
+
+     history.append(response)
+     yield chatbot, history
+
+
+ def reset_user_input():
+     return gr.update(value="")
+
+
+ def reset_state():
+     return [], []
+
+
+ with gr.Blocks() as demo:
+     gr.HTML("""<h1 align="center">ChatGLM.cpp</h1>""")
+
+     chatbot = gr.Chatbot()
+     with gr.Row():
+         with gr.Column(scale=4):
+             user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=8)
+             submitBtn = gr.Button("Submit", variant="primary")
+         with gr.Column(scale=1):
+             max_length = gr.Slider(0, 2048, value=args.max_length, step=1.0, label="Maximum Length", interactive=True)
+             top_p = gr.Slider(0, 1, value=args.top_p, step=0.01, label="Top P", interactive=True)
+             temperature = gr.Slider(0, 1, value=args.temp, step=0.01, label="Temperature", interactive=True)
+             emptyBtn = gr.Button("Clear History")
+
+     history = gr.State([])
+
+     submitBtn.click(
+         predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True
+     )
+     submitBtn.click(reset_user_input, [], [user_input])
+
+     emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
+
+ demo.queue().launch(share=False, inbrowser=True)
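
For context, the Gradio callbacks above all funnel into a single pipeline call; predict() streams pieces into the chatbot by passing stream=True and yielding after each chunk. Below is a minimal standalone sketch of that same inference path without the UI, assuming the chatglm_cpp API this demo targets (Pipeline(model_path), chat() taking a list of history strings); the model path and prompt are placeholders, not part of the commit. Without stream=True, chat() returns the full reply as one string rather than an iterator of pieces.

# Minimal sketch of the inference path used by predict() above; hypothetical
# standalone usage, with "chatglm-ggml.bin" and the prompt as placeholders.
import chatglm_cpp

pipeline = chatglm_cpp.Pipeline("chatglm-ggml.bin")
history = ["Hello! Please introduce yourself."]
# Same sampling parameters as the demo's argparse defaults above.
response = pipeline.chat(
    history,
    max_length=2048,
    max_context_length=512,
    do_sample=True,
    top_k=0,
    top_p=0.7,
    temperature=0.95,
    repetition_penalty=1.0,
)
print(response)

The demo itself would be launched through its own argparse entry point, e.g. python app.py -m chatglm-ggml.bin, which serves this loop behind the Gradio interface defined above.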