# pandora-s's picture
# To Chat Interface
# 971f149 verified
# raw
# history blame
# 2.25 kB
import gradio as gr
from gradio.data_classes import FileData
from huggingface_hub import snapshot_download
from pathlib import Path
import base64
import spaces
import os
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage, TextChunk, ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
# Directory under the user's home where the model files are stored.
models_path = Path.home() / 'pixtral' / 'Pixtral'
models_path.mkdir(parents=True, exist_ok=True)

# Download only the files named in allow_patterns into models_path.
snapshot_download(
    repo_id="mistral-community/pixtral-12b-240910",
    local_dir=models_path,
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
)

# The tokenizer is built from the downloaded tekken.json; the model is loaded
# from the same folder. Both are module-level globals used by run_inference.
tokenizer = MistralTokenizer.from_file(f"{models_path}/tekken.json")
model = Transformer.from_folder(models_path)
def image_to_base64(image_path):
    """Return the file at *image_path* encoded as a base64 ``data:`` URL.

    The MIME type placed in the URL is guessed from the file extension so that
    PNG, GIF, etc. uploads are no longer mislabeled as JPEG. ``image/jpeg``
    (the value this function previously hard-coded) remains the fallback when
    the extension is missing or does not map to an image type.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    import mimetypes  # local import so this block does not depend on a file-level change

    mime_type, _ = mimetypes.guess_type(str(image_path))
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    with open(image_path, 'rb') as img:
        encoded_string = base64.b64encode(img.read()).decode('utf-8')
    return f"data:{mime_type};base64,{encoded_string}"
def _file_path(file):
    """Return the local path for one uploaded-file entry.

    Accepts the ``{"path": ...}`` dict form the original code indexed, an
    object exposing a ``path`` attribute, or a plain path value.
    """
    if isinstance(file, dict):
        return file["path"]
    return getattr(file, "path", file)


def _image_chunks(files):
    """Encode every entry of *files* as an inline-image chunk."""
    return [ImageURLChunk(image_url=image_to_base64(_file_path(file))) for file in files]


def _history_to_messages(history):
    """Turn the chat history into alternating user/assistant messages.

    Returns ``(messages, pending_images)`` where *pending_images* are image
    chunks from trailing file-only turns that never received a reply.

    NOTE(review): assumes *history* is a list of ``[user, assistant]`` pairs in
    which ``user`` is a text string, ``None``, a file tuple ``(path,)`` /
    ``(path, alt_text)`` paired with a ``None`` reply, or the ``(files, text)``
    shape the previous implementation indexed -- confirm against the installed
    Gradio version.
    """
    messages = []
    pending_images = []
    for user, assistant in history:
        text = None
        if isinstance(user, str):
            text = user
        elif isinstance(user, (list, tuple)) and user:
            if isinstance(user[0], (list, tuple)):
                # (files, text): the shape the original loop expected.
                pending_images.extend(_image_chunks(user[0]))
                text = user[1] if len(user) > 1 else None
            else:
                # File turn: only the first element is the file itself.
                pending_images.extend(_image_chunks(user[:1]))

        if assistant is None:
            # No reply for this turn: keep its images for the next user
            # message instead of emitting two user messages in a row.
            continue

        content = list(pending_images)
        if text is not None:
            content.append(TextChunk(text=text))
        if not content:
            # Nothing to show the model for this turn; drop the whole pair so
            # user/assistant messages keep alternating.
            continue
        messages.append(UserMessage(content=content))
        messages.append(AssistantMessage(content=assistant))
        pending_images = []
    return messages, pending_images


@spaces.GPU(duration=30)
def run_inference(message, history):
    """Generate the assistant reply for a multimodal chat turn.

    Args:
        message: Mapping with a ``"files"`` list of uploaded files and a
            ``"text"`` prompt for the current turn.
        history: Earlier turns of the conversation (see
            ``_history_to_messages`` for the accepted shapes).

    Returns:
        The decoded text produced by the model.
    """
    # Debug output kept from the original implementation.
    print(message)
    print(history)

    messages, carried_images = _history_to_messages(history)

    # Current turn: any images still waiting from history, then the new
    # uploads, then the text prompt.
    content = carried_images + _image_chunks(message["files"])
    content.append(TextChunk(text=message["text"]))
    messages.append(UserMessage(content=content))

    completion_request = ChatCompletionRequest(messages=messages)
    encoded = tokenizer.encode_chat_completion(completion_request)

    out_tokens, _ = generate(
        [encoded.tokens],
        model,
        images=[encoded.images],
        max_tokens=512,
        temperature=0.45,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    return tokenizer.decode(out_tokens[0])
# Build the chat UI around run_inference with multimodal input enabled,
# then enable the queue and start the app.
demo = gr.ChatInterface(
    fn=run_inference,
    title="Pixtral 12B",
    multimodal=True,
)
demo.queue().launch()