vilarin commited on
Commit
057b685
·
verified ·
1 Parent(s): b3599a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -40
app.py CHANGED
@@ -3,7 +3,7 @@ import time
3
  import spaces
4
  import torch
5
  import gradio as gr
6
- from threading import Thread
7
 
8
  from huggingface_hub import snapshot_download
9
  from pathlib import Path
@@ -11,9 +11,11 @@ from pathlib import Path
11
  from mistral_inference.transformer import Transformer
12
  from mistral_inference.generate import generate
13
 
 
14
  from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
15
  from mistral_common.protocol.instruct.messages import AssistantMessage, UserMessage
16
  from mistral_common.protocol.instruct.request import ChatCompletionRequest
 
17
 
18
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
19
 
@@ -45,18 +47,20 @@ snapshot_download(repo_id="mistralai/Ministral-8B-Instruct-2410", allow_patterns
45
  # tokenizer
46
  device = "cuda" if torch.cuda.is_available() else "cpu" # for GPU usage or "cpu" for CPU usage
47
  tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
 
 
48
  model = Transformer.from_folder(
49
  mistral_models_path,
50
  device=device,
51
  dtype=torch.bfloat16)
52
-
53
-
54
  @spaces.GPU()
55
  def stream_chat(
56
  message: str,
57
- history: list,
 
58
  temperature: float = 0.3,
59
- max_new_tokens: int = 1024,
60
  ):
61
  print(f'message: {message}')
62
  print(f'history: {history}')
@@ -75,15 +79,27 @@ def stream_chat(
75
  conversation.append(UserMessage(content=message))
76
 
77
  print(f'history: {conversation}')
 
 
 
 
78
 
79
- completion_request = ChatCompletionRequest(messages=conversation)
 
 
 
 
 
 
 
 
80
 
81
  tokens = tokenizer.encode_chat_completion(completion_request).tokens
82
 
83
  out_tokens, _ = generate(
84
  [tokens],
85
  model,
86
- max_tokens=max_new_tokens,
87
  temperature=temperature,
88
  eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
89
 
@@ -93,44 +109,72 @@ def stream_chat(
93
  time.sleep(0.05)
94
  yield result[: i + 1]
95
 
96
- chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  with gr.Blocks(theme="citrus", css=CSS) as demo:
99
  gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
100
- gr.ChatInterface(
101
- fn=stream_chat,
102
- title="Mistral-lab",
103
- chatbot=chatbot,
104
- # type="messages",
105
- fill_height=True,
106
  examples=[
107
- ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
108
- ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
109
- ["Tell me a random fun fact about the Roman Empire."],
110
- ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
111
- ],
112
- cache_examples = False,
113
- additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
114
- additional_inputs=[
115
- gr.Slider(
116
- minimum=0,
117
- maximum=1,
118
- step=0.1,
119
- value=0.3,
120
- label="Temperature",
121
- render=False,
122
- ),
123
- gr.Slider(
124
- minimum=128,
125
- maximum=8192,
126
- step=1,
127
- value=1024,
128
- label="Max new tokens",
129
- render=False,
130
- ),
131
- ],
132
  )
133
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  if __name__ == "__main__":
136
  demo.launch()
 
3
  import spaces
4
  import torch
5
  import gradio as gr
6
+ import json
7
 
8
  from huggingface_hub import snapshot_download
9
  from pathlib import Path
 
11
  from mistral_inference.transformer import Transformer
12
  from mistral_inference.generate import generate
13
 
14
+ from mistral_common.protocol.instruct.tool_calls import Function, Tool
15
  from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
16
  from mistral_common.protocol.instruct.messages import AssistantMessage, UserMessage
17
  from mistral_common.protocol.instruct.request import ChatCompletionRequest
18
+ from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy
19
 
20
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
21
 
 
47
  # tokenizer
48
  device = "cuda" if torch.cuda.is_available() else "cpu" # for GPU usage or "cpu" for CPU usage
49
  tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
50
+ tekken = tokenizer.instruct_tokenizer.tokenizer
51
+ tekken.special_token_policy = SpecialTokenPolicy.IGNORE
52
  model = Transformer.from_folder(
53
  mistral_models_path,
54
  device=device,
55
  dtype=torch.bfloat16)
56
+
 
57
  @spaces.GPU()
58
  def stream_chat(
59
  message: str,
60
+ history: list,
61
+ tools: str,
62
  temperature: float = 0.3,
63
+ max_tokens: int = 1024,
64
  ):
65
  print(f'message: {message}')
66
  print(f'history: {history}')
 
79
  conversation.append(UserMessage(content=message))
80
 
81
  print(f'history: {conversation}')
82
+
83
+ local_namespace = {}
84
+ exec(tools, globals(), local_namespace)
85
+ function_params = local_namespace.get('function_params', {})
86
 
87
+ completion_request = ChatCompletionRequest(
88
+ tools=[
89
+ Tool(
90
+ function=Function(
91
+ **function_params
92
+ )
93
+ )
94
+ ] if tools else None,
95
+ messages=conversation)
96
 
97
  tokens = tokenizer.encode_chat_completion(completion_request).tokens
98
 
99
  out_tokens, _ = generate(
100
  [tokens],
101
  model,
102
+ max_tokens=max_tokens,
103
  temperature=temperature,
104
  eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
105
 
 
109
  time.sleep(0.05)
110
  yield result[: i + 1]
111
 
112
+
113
+ tools_schema = """function_params = {
114
+ name="get_current_weather",
115
+ description="Get the current weather",
116
+ parameters={
117
+ "type": "object",
118
+ "properties": {
119
+ "location": {
120
+ "type": "string",
121
+ "description": "The city and state, e.g. San Francisco, CA",
122
+ },
123
+ "format": {
124
+ "type": "string",
125
+ "enum": ["celsius", "fahrenheit"],
126
+ "description": "The temperature unit to use. Infer this from the users location.",
127
+ },
128
+ },
129
+ "required": ["location", "format"],
130
+ },
131
+ }"""
132
 
133
  with gr.Blocks(theme="citrus", css=CSS) as demo:
134
  gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
135
+ chatbot = gr.Chatbot(
136
+ height=600,
137
+ placeholder=PLACEHOLDER,
 
 
 
138
  examples=[
139
+ {'text': "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."},
140
+ {'text': "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."},
141
+ {'text': "Tell me a random fun fact about the Roman Empire."},
142
+ {'text': "Show me a code snippet of a website's sticky header in CSS and JavaScript."},
143
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  )
145
+ msg = gr.Textbox(
146
+ value = "",
147
+ label = "Chat"
148
+ )
149
+ with gr.Row():
150
+ send = gr.Button(
151
+ value = "Send",
152
+ size="lg",
153
+ variant = "primary",
154
+ )
155
+ clear = gr.ClearButton([msg, chatbot])
156
+
157
+ with gr.Accordion(label="⚙️ Parameters", open=True,):
158
+ tools = gr.Textbox(
159
+ value = tools_schema,
160
+ lable = "Tools schema",
161
+ )
162
+ temperature = gr.Slider(
163
+ minimum=0,
164
+ maximum=1,
165
+ step=0.1,
166
+ value=0.3,
167
+ label="Temperature",
168
+ ),
169
+ max_tokens = gr.Slider(
170
+ minimum=128,
171
+ maximum=8192,
172
+ step=1,
173
+ value=1024,
174
+ label="Max new tokens",
175
+ )
176
+ msg.submit(fn = stream_chat, inputs = [msg, chatbot, tools, temperature, max_tokens], outputs = [chatbot])
177
+ send.click(fn = stream_chat, inputs = [msg, chatbot, tools, temperature, max_tokens], outputs = [chatbot])
178
 
179
  if __name__ == "__main__":
180
  demo.launch()