Hristo ZHANG 张鹤立 commited on
Commit
8f263fc
·
1 Parent(s): cc4e355
Files changed (1) hide show
  1. app.py +8 -23
app.py CHANGED
@@ -10,7 +10,7 @@ DEFAULT_MODEL_PATH = model_file
10
  parser = argparse.ArgumentParser()
11
  parser.add_argument("-m", "--model", default=DEFAULT_MODEL_PATH, type=Path, help="model path")
12
  parser.add_argument("--mode", default="chat", type=str, choices=["chat", "generate"], help="inference mode")
13
- parser.add_argument("-l", "--max_length", default=2048, type=int, help="max total length including prompt and output")
14
  parser.add_argument("-c", "--max_context_length", default=512, type=int, help="max context length")
15
  parser.add_argument("--top_k", default=0, type=int, help="top-k sampling")
16
  parser.add_argument("--top_p", default=0.7, type=float, help="top-p sampling")
@@ -25,27 +25,14 @@ llm = Llama(model_path=model_file)
25
 
26
 
27
 
28
- def predict(input, system_prompt, chatbot, max_length, ctx_length, top_p, temperature, history):
29
  chatbot.append((input, ""))
30
  response = ""
31
  history.append(input)
32
 
33
- generation_kwargs = dict(
34
- max_length=max_length,
35
- max_context_length=ctx_length,
36
- do_sample=temperature > 0,
37
- top_k=40,
38
- top_p=top_p,
39
- temperature=temperature,
40
- repetition_penalty=1.1,
41
- num_threads=0,
42
- stream=True,
43
- )
44
- output = llm(input)
45
- response = output['choices'][0]['text']
46
-
47
- for response_piece in response:
48
- response += response_piece
49
  chatbot[-1] = (chatbot[-1][0], response)
50
 
51
  yield chatbot, history
@@ -63,17 +50,15 @@ def reset_state():
63
 
64
 
65
  with gr.Blocks() as demo:
66
- gr.HTML("""<h1 align="center">01-Yi 6B</h1>""")
67
 
68
  chatbot = gr.Chatbot()
69
  with gr.Row():
70
  with gr.Column(scale=4):
71
- system_prompt = gr.Textbox(show_label=False, placeholder="system prompt ...", lines=2)
72
- user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=6)
73
  submitBtn = gr.Button("Submit", variant="primary")
74
  with gr.Column(scale=1):
75
  max_length = gr.Slider(0, 32048, value=args.max_length, step=1.0, label="Maximum Length", interactive=True)
76
- ctx_length = gr.Slider(0, 4096, value=512, step=1.0, label="Maximum Context Length", interactive=True)
77
  top_p = gr.Slider(0, 1, value=args.top_p, step=0.01, label="Top P", interactive=True)
78
  temperature = gr.Slider(0, 1, value=args.temp, step=0.01, label="Temperature", interactive=True)
79
  emptyBtn = gr.Button("Clear History")
@@ -81,7 +66,7 @@ with gr.Blocks() as demo:
81
  history = gr.State([])
82
 
83
  submitBtn.click(
84
- predict, [user_input, system_prompt, chatbot, max_length, ctx_length, top_p, temperature, history], [chatbot, history], show_progress=True
85
  )
86
  submitBtn.click(reset_user_input, [], [user_input])
87
 
 
10
  parser = argparse.ArgumentParser()
11
  parser.add_argument("-m", "--model", default=DEFAULT_MODEL_PATH, type=Path, help="model path")
12
  parser.add_argument("--mode", default="chat", type=str, choices=["chat", "generate"], help="inference mode")
13
+ parser.add_argument("-l", "--max_length", default=512, type=int, help="max total length including prompt and output")
14
  parser.add_argument("-c", "--max_context_length", default=512, type=int, help="max context length")
15
  parser.add_argument("--top_k", default=0, type=int, help="top-k sampling")
16
  parser.add_argument("--top_p", default=0.7, type=float, help="top-p sampling")
 
25
 
26
 
27
 
28
+ def predict(input, chatbot, max_length, top_p, temperature, history):
29
  chatbot.append((input, ""))
30
  response = ""
31
  history.append(input)
32
 
33
+ for output in llm(input, stream=True, temperature=temperature, top_p=top_p, max_tokens=max_length, ):
34
+ piece = output['choices'][0]['text']
35
+ response += piece
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  chatbot[-1] = (chatbot[-1][0], response)
37
 
38
  yield chatbot, history
 
50
 
51
 
52
  with gr.Blocks() as demo:
53
+ gr.HTML("""<h1 align="center">Yi-6B-GGUF by llama.cpp</h1>""")
54
 
55
  chatbot = gr.Chatbot()
56
  with gr.Row():
57
  with gr.Column(scale=4):
58
+ user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=8)
 
59
  submitBtn = gr.Button("Submit", variant="primary")
60
  with gr.Column(scale=1):
61
  max_length = gr.Slider(0, 32048, value=args.max_length, step=1.0, label="Maximum Length", interactive=True)
 
62
  top_p = gr.Slider(0, 1, value=args.top_p, step=0.01, label="Top P", interactive=True)
63
  temperature = gr.Slider(0, 1, value=args.temp, step=0.01, label="Temperature", interactive=True)
64
  emptyBtn = gr.Button("Clear History")
 
66
  history = gr.State([])
67
 
68
  submitBtn.click(
69
+ predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True
70
  )
71
  submitBtn.click(reset_user_input, [], [user_input])
72