pandora-s commited on
Commit
971f149
·
verified ·
1 Parent(s): 8831e89

To Chat Interface

Browse files

Quick PR still not finished to make it a chat interface instead! Almost done, just history logic to be done, will do later 👍

Files changed (1) hide show
  1. app.py +17 -25
app.py CHANGED
@@ -10,7 +10,7 @@ from mistral_inference.transformer import Transformer
10
  from mistral_inference.generate import generate
11
 
12
  from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
13
- from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk
14
  from mistral_common.protocol.instruct.request import ChatCompletionRequest
15
 
16
  models_path = Path.home().joinpath('pixtral', 'Pixtral')
@@ -29,9 +29,20 @@ def image_to_base64(image_path):
29
  return f"data:image/jpeg;base64,{encoded_string}"
30
 
31
  @spaces.GPU(duration=30)
32
- def run_inference(image_url, prompt):
33
- base64 = image_to_base64(image_url)
34
- completion_request = ChatCompletionRequest(messages=[UserMessage(content=[ImageURLChunk(image_url=base64), TextChunk(text=prompt)])])
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  encoded = tokenizer.encode_chat_completion(completion_request)
37
 
@@ -40,26 +51,7 @@ def run_inference(image_url, prompt):
40
 
41
  out_tokens, _ = generate([tokens], model, images=[images], max_tokens=512, temperature=0.45, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
42
  result = tokenizer.decode(out_tokens[0])
43
- return [[prompt, result]]
44
-
45
- with gr.Blocks() as demo:
46
- with gr.Row():
47
- image_box = gr.Image(type="filepath")
48
-
49
- chatbot = gr.Chatbot(
50
- scale = 2,
51
- height=750
52
- )
53
- text_box = gr.Textbox(
54
- placeholder="Enter your text and press enter, or upload an image.",
55
- container=False,
56
- )
57
-
58
-
59
- btn = gr.Button("Submit")
60
- clicked = btn.click(run_inference,
61
- [image_box,text_box],
62
- chatbot
63
- )
64
 
 
65
  demo.queue().launch()
 
10
  from mistral_inference.generate import generate
11
 
12
  from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
13
+ from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage, TextChunk, ImageURLChunk
14
  from mistral_common.protocol.instruct.request import ChatCompletionRequest
15
 
16
  models_path = Path.home().joinpath('pixtral', 'Pixtral')
 
29
  return f"data:image/jpeg;base64,{encoded_string}"
30
 
31
  @spaces.GPU(duration=30)
32
+ def run_inference(message, history):
33
+ print(message)
34
+ print(history)
35
+
36
+ ## to be fixed
37
+ messages = []
38
+ for couple in history:
39
+ messages.append(UserMessage(content = [ImageURLChunk(image_url=image_to_base64(file["path"])) for file in couple[0][0]]+[TextChunk(text=couple[0][1])]))
40
+ messages.append(AssistantMessage(content = couple[1]))
41
+ ##
42
+
43
+ messages.append(UserMessage(content = [ImageURLChunk(image_url=image_to_base64(file["path"])) for file in message["files"]]+[TextChunk(text=message["text"])]))
44
+
45
+ completion_request = ChatCompletionRequest(messages=messages)
46
 
47
  encoded = tokenizer.encode_chat_completion(completion_request)
48
 
 
51
 
52
  out_tokens, _ = generate([tokens], model, images=[images], max_tokens=512, temperature=0.45, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
53
  result = tokenizer.decode(out_tokens[0])
54
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ demo = gr.ChatInterface(fn=run_inference, title="Pixtral 12B", multimodal=True)
57
  demo.queue().launch()