Hristo ZHANG 张鹤立 commited on
Commit
a023bcb
·
1 Parent(s): baa3fe5
Files changed (2) hide show
  1. app.py +76 -14
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,25 +1,87 @@
1
  import gradio as gr
2
- from llama_cpp import Llama
3
  import os
 
 
4
  model_file = "Yi-6B.q4_k_m.gguf"
5
  if not os.path.isfile("Yi-6B.q4_k_m.gguf"):
6
  os.system("wget -c https://huggingface.co/SamPurkis/Yi-6B-GGUF/resolve/main/Yi-6B.q4_k_m.gguf")
7
 
8
- llm = Llama(model_path=model_file)
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- def generate_text(input_text):
11
- output = llm(f"Human: {input_text} A:", max_tokens=512, stop=["Assistant:", "\n"], echo=True)
12
- return output['choices'][0]['text']
13
 
14
- input_text = gr.inputs.Textbox(lines= 10, label="Enter your input text")
15
- output_text = gr.outputs.Textbox(label="Output text")
16
 
17
- description = "llama.cpp implementation in python [https://github.com/abetlen/llama-cpp-python]"
18
 
19
- examples = [
20
- ["What is the capital of France? ", "The capital of France is Paris."],
21
- ["Who wrote the novel 'Pride and Prejudice'?", "The novel 'Pride and Prejudice' was written by Jane Austen."],
22
- ["What is the square root of 64?", "The square root of 64 is 8."]
23
- ]
24
 
25
- gr.Interface(fn=generate_text, inputs=input_text, outputs=output_text, title="Llama Language Model", description=description, examples=examples).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
  import os
3
+ from pathlib import Path
4
+ import argparse
5
  model_file = "Yi-6B.q4_k_m.gguf"
6
  if not os.path.isfile("Yi-6B.q4_k_m.gguf"):
7
  os.system("wget -c https://huggingface.co/SamPurkis/Yi-6B-GGUF/resolve/main/Yi-6B.q4_k_m.gguf")
8
 
9
+ DEFAULT_MODEL_PATH = model_file
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument("-m", "--model", default=DEFAULT_MODEL_PATH, type=Path, help="model path")
12
+ parser.add_argument("--mode", default="chat", type=str, choices=["chat", "generate"], help="inference mode")
13
+ parser.add_argument("-l", "--max_length", default=2048, type=int, help="max total length including prompt and output")
14
+ parser.add_argument("-c", "--max_context_length", default=512, type=int, help="max context length")
15
+ parser.add_argument("--top_k", default=0, type=int, help="top-k sampling")
16
+ parser.add_argument("--top_p", default=0.7, type=float, help="top-p sampling")
17
+ parser.add_argument("--temp", default=0.95, type=float, help="temperature")
18
+ parser.add_argument("--repeat_penalty", default=1.1, type=float, help="penalize repeat sequence of tokens")
19
+ parser.add_argument("-t", "--threads", default=0, type=int, help="number of threads for inference")
20
+ parser.add_argument("--plain", action="store_true", help="display in plain text without markdown support")
21
+ args = parser.parse_args()
22
 
23
+ from pyllamacpp.model import Model
24
+ model = Model(model_path=model_file)
 
25
 
 
 
26
 
 
27
 
28
+ def predict(input, system_prompt, chatbot, max_length, ctx_length, top_p, temperature, history):
29
+ chatbot.append((input, ""))
30
+ response = ""
31
+ history.append(input)
 
32
 
33
+ generation_kwargs = dict(
34
+ max_length=max_length,
35
+ max_context_length=ctx_length,
36
+ do_sample=temperature > 0,
37
+ top_k=40,
38
+ top_p=top_p,
39
+ temperature=temperature,
40
+ repetition_penalty=1.1,
41
+ num_threads=0,
42
+ stream=True,
43
+ )
44
+ for response_piece in model.generate(input):
45
+ response += response_piece
46
+ chatbot[-1] = (chatbot[-1][0], response)
47
+
48
+ yield chatbot, history
49
+
50
+ history.append(response)
51
+ yield chatbot, history
52
+
53
+
54
+ def reset_user_input():
55
+ return gr.update(value="")
56
+
57
+
58
+ def reset_state():
59
+ return [], []
60
+
61
+
62
+ with gr.Blocks() as demo:
63
+ gr.HTML("""<h1 align="center">01-Yi 6B</h1>""")
64
+
65
+ chatbot = gr.Chatbot()
66
+ with gr.Row():
67
+ with gr.Column(scale=4):
68
+ system_prompt = gr.Textbox(show_label=False, placeholder="system prompt ...", lines=2)
69
+ user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=6)
70
+ submitBtn = gr.Button("Submit", variant="primary")
71
+ with gr.Column(scale=1):
72
+ max_length = gr.Slider(0, 32048, value=args.max_length, step=1.0, label="Maximum Length", interactive=True)
73
+ ctx_length = gr.Slider(0, 4096, value=512, step=1.0, label="Maximum Context Length", interactive=True)
74
+ top_p = gr.Slider(0, 1, value=args.top_p, step=0.01, label="Top P", interactive=True)
75
+ temperature = gr.Slider(0, 1, value=args.temp, step=0.01, label="Temperature", interactive=True)
76
+ emptyBtn = gr.Button("Clear History")
77
+
78
+ history = gr.State([])
79
+
80
+ submitBtn.click(
81
+ predict, [user_input, system_prompt, chatbot, max_length, ctx_length, top_p, temperature, history], [chatbot, history], show_progress=True
82
+ )
83
+ submitBtn.click(reset_user_input, [], [user_input])
84
+
85
+ emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
86
+
87
+ demo.queue().launch(share=False, inbrowser=True)
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
  gradio
2
- llama-cpp-python
 
1
  gradio
2
+ git+https://github.com/zhangheli/pyllamacpp.git@main