# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import base64
import dataclasses
from enum import Enum, auto
from io import BytesIO
from typing import Any, List, Tuple, Union

try:
    from PIL import Image
except ImportError:  # Nothing in this module uses Pillow, so keep it optional.
    Image = None


class SeparatorStyle(Enum):
    """Different separator style."""

    TWO = auto()
    PLAIN = auto()
    CHATML = auto()
    LLAMA_2 = auto()
    LLAMA_3 = auto()
    QWEN2 = auto()


def _message_text(message):
    """Return the text of a message.

    A message is either a plain string or a tuple whose first element is the
    text and whose remaining elements are attachments that never appear in
    the rendered prompt.
    """
    return message[0] if isinstance(message, tuple) else message


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    ``messages`` is a list of ``[role, message]`` pairs. ``message`` may be a
    string, a tuple whose first element is the text, or a falsy value
    (``None`` / ``""``) marking a turn that has not been generated yet.
    """

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.PLAIN
    sep: str = "###"
    sep2: Union[str, None] = None
    version: str = "Unknown"

    tokenizer_id: str = ""
    tokenizer: Any = None
    # Stop criteria (the default one is EOS token)
    stop_str: Union[str, List[str], None] = None
    # Stops generation if meeting any token in this list
    stop_token_ids: Union[List[int], None] = None

    skip_next: bool = False

    def get_prompt(self):
        """Render the conversation as one prompt string.

        Returns:
            The prompt formatted according to ``self.sep_style``.

        Raises:
            ValueError: if ``self.sep_style`` is not a supported style.
            AssertionError: for LLAMA_2 when the first message is empty or
                does not come from ``self.roles[0]``, and for QWEN2 when a
                message contains ``<|endoftext|>`` anywhere but at its end.
        """
        style = self.sep_style
        if style == SeparatorStyle.TWO:
            return self._prompt_two()
        if style == SeparatorStyle.LLAMA_3:
            return self._prompt_llama_3()
        if style == SeparatorStyle.LLAMA_2:
            return self._prompt_llama_2()
        if style == SeparatorStyle.PLAIN:
            return self._prompt_plain()
        if style == SeparatorStyle.CHATML:
            return self._prompt_chatml()
        if style == SeparatorStyle.QWEN2:
            return self._prompt_qwen2()
        raise ValueError(f"Invalid style: {self.sep_style}")

    def _prompt_two(self):
        """Build a ``ROLE: text<sep>`` prompt, alternating sep / sep2 per turn."""
        seps = [self.sep, self.sep2]
        parts = [self.system + seps[0]]
        for i, (role, message) in enumerate(self.messages):
            if message:
                parts.append(role + ": " + _message_text(message) + seps[i % 2])
            else:
                # Pending turn: leave the role open for the model to complete.
                parts.append(role + ":")
        return "".join(parts)

    def _prompt_llama_3(self):
        """Build a Llama 3 prompt made of header blocks closed by ``sep2``."""
        system = self.system
        if system:
            system = (
                f"<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
            )
        parts = ["<|begin_of_text|>" + system]
        for role, message in self.messages:
            parts.append(f"<|start_header_id|>{role}<|end_header_id|>\n\n")
            if message:
                parts.append(_message_text(message).strip() + self.sep2)
        return "".join(parts)

    def _prompt_llama_2(self):
        """Build a Llama 2 ``[INST] ... [/INST]`` prompt."""
        parts = []
        for i, (role, message) in enumerate(self.messages):
            if i == 0:
                assert message, "first message should not be none"
                assert role == self.roles[0], "first message should come from user"
            if not message:
                continue
            text = _message_text(message)
            if i == 0:
                # The system prompt travels inside the first user instruction.
                system = self.system
                if system:
                    system = f"<<SYS>>\n{system}\n<</SYS>>\n\n"
                text = system + text
            if i % 2 == 0:
                parts.append(self.sep + f"[INST] {text} [/INST]")
            else:
                parts.append(" " + text + " " + self.sep2)
        return "".join(parts).lstrip(self.sep)

    def _prompt_plain(self):
        """Concatenate the message texts without role names."""
        seps = [self.sep, self.sep2]
        parts = [self.system]
        for i, (_, message) in enumerate(self.messages):
            if message:
                parts.append(_message_text(message) + seps[i % 2])
        return "".join(parts)

    def _prompt_chatml(self):
        """Build a ChatML-style prompt: ``role\\ntext<sep>\\n`` per turn."""
        parts = [] if self.system == "" else [self.system + self.sep + "\n"]
        for role, message in self.messages:
            if message:
                parts.append(role + "\n" + _message_text(message) + self.sep + "\n")
            else:
                parts.append(role + "\n")
        return "".join(parts)

    def _prompt_qwen2(self):
        """Build a Qwen2 prompt with ``<|im_start|>`` / ``<|im_end|>`` markers."""
        start = "<|im_start|>"
        end = "<|im_end|>\n"
        eot = "<|endoftext|>"
        parts = [start + "system\n" + self.system + end]
        for role, message in self.messages:
            if not message:
                parts.append(start + role + "\n")
                continue
            text = _message_text(message)
            if text.endswith(eot):
                # Move the end-of-text marker behind the closing <|im_end|>.
                text = text.replace(eot, "")
                parts.append(start + role + "\n" + text + end + eot)
            else:
                assert eot not in text, f"Invalid message: {text}"
                parts.append(start + role + "\n" + text + end)
        return "".join(parts)

    def append_message(self, role, message):
        """Append a ``[role, message]`` pair to the history."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Return the history after ``offset`` as ``[user, reply]`` pairs.

        Even-indexed messages open a new pair (tuple messages contribute only
        their text); odd-indexed messages fill in the reply of the last pair.
        """
        pairs = []
        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                pairs.append([_message_text(msg), None])
            else:
                pairs[-1][-1] = msg
        return pairs

    def copy(self):
        """Return a copy with its own message list.

        Every field is carried over (including the tokenizer and stop
        criteria); only the ``[role, message]`` pairs are rebuilt so that
        appending to the copy does not touch this conversation.
        """
        return dataclasses.replace(
            self, messages=[[role, msg] for role, msg in self.messages]
        )

    def dict(self):
        """Return a plain-dict view of the conversation.

        If any message is a tuple, the messages are rebuilt with text only;
        otherwise ``self.messages`` is returned as is.
        """
        if any(isinstance(msg, tuple) for _, msg in self.messages):
            messages = [[role, _message_text(msg)] for role, msg in self.messages]
        else:
            messages = self.messages
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
)

conv_llama_2 = Conversation(
    system="You are a helpful language and speech assistant. "
    "You are able to understand the speech content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="",
    sep2="",
)

conv_llama_3 = Conversation(
    system="You are a helpful language and speech assistant. "
    "You are able to understand the speech content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("user", "assistant"),
    version="llama_v3",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.LLAMA_3,
    sep="",
    sep2="<|eot_id|>",
)

conv_qwen_v1 = Conversation(
    system="You are a helpful assistant.",
    roles=("user", "assistant"),
    version="v1",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.QWEN2,
)

# NOTE: sep2 is not set here, so the PLAIN prompt can only be rendered while
# every odd-indexed message is empty (a non-empty one is joined with sep2).
conv_plain = Conversation(
    system="",
    roles=("", ""),
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="",
)

conv_qwen = Conversation(
    system="""<|im_start|>system
You are a helpful assistant.""",
    roles=("<|im_start|>user", "<|im_start|>assistant"),
    version="qwen",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.CHATML,
    sep="<|im_end|>",
)

default_conversation = conv_llama_3
conv_templates = {
    "v1": conv_vicuna_v1,
    "plain": conv_plain,
    "llama_2": conv_llama_2,
    "llama_3": conv_llama_3,
    "v1_qwen2": conv_qwen_v1,
    "qwen_1_5": conv_qwen,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())