|
import logging |
|
from dataclasses import dataclass |
|
from enum import Enum |
|
from typing import Any, Dict, List, Optional, Union |
|
|
|
from mistral_common.protocol.instruct.messages import ( |
|
FinetuningAssistantMessage, |
|
Roles, |
|
SystemMessage, |
|
ToolMessage, |
|
UserMessage, |
|
) |
|
from mistral_common.protocol.instruct.tool_calls import ( |
|
Function, |
|
FunctionCall, |
|
Tool, |
|
ToolCall, |
|
) |
|
from mistral_common.protocol.instruct.validator import ( |
|
MistralRequestValidatorV3, |
|
ValidationMode, |
|
) |
|
from mistral_common.tokens.instruct.request import InstructRequest |
|
from mistral_common.tokens.tokenizers.base import Tokenizer |
|
from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase |
|
|
|
from .exceptions import ( |
|
ConversationFormatError, |
|
FunctionFormatError, |
|
MessageFormatError, |
|
ToolCallFormatError, |
|
UnrecognizedRoleError, |
|
) |
|
|
|
# Module-level logger registered under the fixed channel name "tokenize".
logger = logging.getLogger("tokenize")

# Aliases used by TokenSample: a list of token ids and a parallel list of
# per-token boolean mask flags.
Sequence = List[int]
Mask = List[bool]
|
|
|
|
|
class TrainingInstructSample(InstructRequest):
    """``InstructRequest`` variant produced by ``build_instruct_sample``.

    Declares the tool list and the ``only_last`` flag that
    ``tokenize_instruct`` reads when building the token masks.
    """

    # Tool definitions attached to the conversation; None when none are given.
    available_tools: Optional[List[Tool]] = None

    # When True, tokenize_instruct gives True masks only to an assistant
    # message that is the last message of the sample.
    only_last: bool = False
|
|
|
|
|
@dataclass
class TokenSample:
    """Result of tokenization: token ids plus one boolean mask flag per token."""

    # Token ids of the encoded sample.
    tokens: Sequence

    # Boolean flags built alongside ``tokens`` by the tokenize_* functions.
    masks: Mask
|
|
|
|
|
class SampleType(str, Enum):
    """Kind of sample that ``encode`` should build from a raw record.

    The enum also derives from ``str``, so members compare equal to their
    plain string values.
    """

    # Record is handled by get_pretrain_sample.
    PRETRAIN = "pretrain"

    # Record is handled by build_instruct_sample.
    INSTRUCT = "instruct"
|
|
|
|
|
def encode(
    data: Dict[str, Any],
    instruct_tokenizer: InstructTokenizerBase,
    as_type: SampleType,
) -> TokenSample:
    """Turn one raw data record into a tokenized sample.

    Args:
        data: Raw record. It is handed to ``get_pretrain_sample`` for
            ``SampleType.PRETRAIN`` and to ``build_instruct_sample`` for
            ``SampleType.INSTRUCT``.
        instruct_tokenizer: Tokenizer forwarded to ``tokenize``.
        as_type: Which kind of sample to build from ``data``.

    Returns:
        The ``TokenSample`` produced by ``tokenize``.

    Raises:
        ValueError: If ``as_type`` is neither ``SampleType.PRETRAIN`` nor
            ``SampleType.INSTRUCT``.
    """
    sample: Union[str, TrainingInstructSample]
    if as_type == SampleType.PRETRAIN:
        sample = get_pretrain_sample(data)
    elif as_type == SampleType.INSTRUCT:
        sample = build_instruct_sample(data)
    else:
        # Without this branch ``sample`` stays unbound and the call below
        # fails with an UnboundLocalError that hides the real cause.
        raise ValueError(
            f"`as_type` has to be one of {[t.value for t in SampleType]}, not {as_type!r}."
        )

    return tokenize(sample=sample, instruct_tokenizer=instruct_tokenizer)
|
|
|
|
|
def get_pretrain_sample(data: Dict[str, Any]) -> str:
    """Extract the raw pretraining text from a data record.

    The text must be stored under exactly one of the keys ``'text'`` or
    ``'content'``.

    Args:
        data: Record holding the text under ``'text'`` or ``'content'``.

    Returns:
        The string stored under whichever of the two keys is present.

    Raises:
        AssertionError: If both keys are present, if neither key holds a
            non-None value, or if the stored value is not a ``str``.
    """
    found_keys = [key for key in ("text", "content") if key in data]

    assert (
        len(found_keys) < 2
    ), "Make sure to have either 'text' or 'content' in your data. Not both."

    # At most one key is present at this point, so inspecting it is enough.
    assert (
        found_keys and data[found_keys[0]] is not None
    ), f"Must have one of 'text' or 'content' in your data. Only have {data.keys()}"

    sample = data[found_keys[0]]
    assert isinstance(sample, str), sample

    return sample
|
|
|
|
|
def build_instruct_sample(data: Dict[str, Any]) -> TrainingInstructSample:
    """Parse a raw conversation record into a ``TrainingInstructSample``.

    The record must hold its message list under exactly one of ``'messages'``
    or ``'interactions'``. It may hold tool definitions under at most one of
    ``'tools'`` or ``'available_tools'``, plus optional ``'system_prompt'``
    and ``'only_last'`` entries. Every message needs a ``'role'`` and, unless
    it carries ``'tool_calls'``, a body under exactly one of ``'content'`` or
    ``'text'``.

    Args:
        data: Raw conversation record.

    Returns:
        The sample built from the parsed messages, system prompt and tools.
        ``only_last`` is truthy when the record requests it or whenever
        ``available_tools`` is not None.

    Raises:
        ConversationFormatError: If neither or both message-list keys are
            present, or if both tool keys are present.
        MessageFormatError: If a message lacks ``'role'``, has both body keys,
            has no body and no tool calls, or if more than one system prompt
            is supplied.
        UnrecognizedRoleError: If a role is not a value of ``Roles``.
        AssertionError: If the message list is None, or a user/tool message
            ends up without a body.
    """
    messages_keys = ["messages", "interactions"]
    content_keys = ["content", "text"]
    allowed_roles = [role.value for role in Roles]

    messages: List[
        SystemMessage | UserMessage | FinetuningAssistantMessage | ToolMessage
    ] = []

    # Raw values first; the tools are replaced by parsed ``Tool`` objects below.
    available_tools: Optional[List[Tool]] = data.get("available_tools")
    system_prompt = data.get("system_prompt")

    # Exactly one of the message-list keys has to be present.
    found_messages_keys = [key for key in messages_keys if key in data]

    if not found_messages_keys:
        err = f"The conversation does not contain one of '{', '.join(messages_keys)}' key, but only {', '.join(data.keys())}. Make sure that the conversation includes one of '{', '.join(messages_keys)}'."
        raise ConversationFormatError(err, str(data))

    if len(found_messages_keys) == len(messages_keys):
        err = f"The conversation cannot contain both of '{', '.join(messages_keys)}' key, but only one of the two."
        raise ConversationFormatError(err, str(data))

    data_messages: Optional[List[Dict[str, Any]]] = data[found_messages_keys[0]]
    assert data_messages is not None, "data_messages can't be None"

    if "available_tools" in data and "tools" in data:
        err = "The conversation contains both an `available_tools` and `tools` key. You can only have one."
        raise ConversationFormatError(err, str(data))

    # At most one tool key is present here; only a non-empty list is parsed.
    for tools_key in ("tools", "available_tools"):
        raw_tools = data.get(tools_key)
        if raw_tools is not None and len(raw_tools) > 0:
            available_tools = _parse_available_tools(raw_tools)
            break

    for data_message in data_messages:
        has_tool_calls = data_message.get("tool_calls") is not None

        if "role" not in data_message:
            err = f"A message does not contain a 'role' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'role'."
            raise MessageFormatError(err, str(data))

        role = data_message["role"]

        if "content" in data_message and "text" in data_message:
            err = f"A {role} message contains both a 'text' and 'content' key. Make sure that there is only one of the two."
            raise MessageFormatError(err, str(data))

        # 'content' takes precedence; 'text' is the fallback.
        content: Optional[str] = data_message.get("content")
        if content is None:
            content = data_message.get("text")

        # Only messages carrying tool calls may come without a body.
        if content is None and not has_tool_calls:
            err = f"A {role} message does not contain one of '{content_keys}' key, but only {', '.join(data_message.keys())}. Make sure that the message includes one of '{content_keys}' keys."
            raise MessageFormatError(err, str(data))

        if role not in allowed_roles:
            raise UnrecognizedRoleError(role, allowed_roles)

        if role == "user":
            assert content is not None
            messages.append(UserMessage(content=content))
        elif role == "assistant":
            tool_calls: Optional[List[ToolCall]] = (
                _parse_tool_calls(data_message["tool_calls"])
                if has_tool_calls
                else None
            )
            assistant_message = FinetuningAssistantMessage(
                content=content,
                tool_calls=tool_calls,
                weight=data_message.get("weight"),
            )
            messages.append(assistant_message)
        elif role == "system":
            # The system prompt is kept apart from the message list; a second
            # one (including a top-level 'system_prompt' entry) is rejected.
            if system_prompt is not None:
                err = "Multiple messages with role 'system' encountered. Only one is allowed."
                raise MessageFormatError(err, str(data))

            system_prompt = content
        elif role == "tool":
            assert content is not None
            messages.append(_parse_tool_message(content, data_message))

    # Validate the assembled messages and tools in finetuning mode.
    validator = MistralRequestValidatorV3(ValidationMode.finetuning)
    validator.validate_messages(messages)
    validator._validate_tools(available_tools or [])

    only_last = data.get("only_last", False) or available_tools is not None

    return TrainingInstructSample(
        messages=messages,
        system_prompt=system_prompt,
        available_tools=available_tools,
        only_last=only_last,
    )
|
|
|
|
|
def _parse_available_tools(tools: List[Dict[str, Any]]) -> List[Tool]:
    """Convert raw tool dicts into ``Tool`` objects.

    Args:
        tools: Tool dicts, each holding a ``'function'`` entry with the keys
            ``'name'``, ``'description'`` and ``'parameters'``.

    Returns:
        One ``Tool`` per input dict, in input order.

    Raises:
        FunctionFormatError: If a tool has no ``'function'`` key, if the
            function lacks one of the required keys, or if its
            ``'parameters'`` value is not a ``dict``.
    """
    available_tools = []
    for tool in tools:
        if "function" not in tool:
            raise FunctionFormatError(
                "A tool dict does not have a 'function' key.", str(tool)
            )

        func_data = tool["function"]

        for key in ["name", "description", "parameters"]:
            if key not in func_data:
                raise FunctionFormatError(
                    f"A function dict does not have a {key} key.", str(func_data)
                )

        if not isinstance(func_data["parameters"], dict):
            # Message fixed: "empyt" typo and stray trailing space removed.
            raise FunctionFormatError(
                f"A function 'parameters' key has to be of type dict, but is {type(func_data['parameters'])}. If the function has no parameters pass an empty dict.",
                str(func_data),
            )

        function = Function(
            name=func_data["name"],
            description=func_data["description"],
            parameters=func_data["parameters"],
        )

        available_tools.append(Tool(function=function))
    return available_tools
|
|
|
|
|
def _parse_tool_calls(calls: List[Dict[str, Any]]) -> List[ToolCall]:
    """Convert the raw tool calls of an assistant message into ``ToolCall`` objects.

    Args:
        calls: Tool-call dicts, each with an ``'id'`` and a ``'function'``
            entry that holds a ``'name'`` and a string ``'arguments'``.

    Returns:
        One ``ToolCall`` per input dict, in input order.

    Raises:
        ToolCallFormatError: If any call lacks ``'id'`` or ``'function'``, if
            any function lacks ``'name'`` or ``'arguments'``, or if any
            ``'arguments'`` value is not a ``str``.
    """
    # Checks run key by key across all calls, so the reported key does not
    # depend on which call is malformed first.
    for key in ("id", "function"):
        if any(key not in call for call in calls):
            err = f"A tool call of an assistant message does not have a {key} key"
            raise ToolCallFormatError(err, str(calls))

    functions = [call["function"] for call in calls]

    for key in ("name", "arguments"):
        if any(key not in function for function in functions):
            err = (
                f"A tool call function of an assistant message does not have a {key} key"
            )
            raise ToolCallFormatError(err, str(calls))

    if any(not isinstance(function["arguments"], str) for function in functions):
        err = "A tool call function of an assistant message does not have a 'arguments' key of type str"
        raise ToolCallFormatError(err, str(calls))

    return [
        ToolCall(
            id=call["id"],
            function=FunctionCall(
                name=function["name"],
                arguments=function["arguments"],
            ),
        )
        for call, function in zip(calls, functions)
    ]
|
|
|
|
|
def _parse_tool_message(content: str, data_message: Dict[str, Any]) -> ToolMessage:
    """Build a ``ToolMessage`` from a raw message with role ``'tool'``.

    Args:
        content: Body of the tool message.
        data_message: Raw message dict; must hold ``'tool_call_id'`` and may
            hold ``'name'``.

    Returns:
        The ``ToolMessage`` carrying the content, the call id and the
        optional name (None when absent).

    Raises:
        MessageFormatError: If ``'tool_call_id'`` is missing.
    """
    if "tool_call_id" in data_message:
        return ToolMessage(
            content=content,
            tool_call_id=data_message["tool_call_id"],
            name=data_message.get("name"),
        )

    err = f"A tool message does not contain a 'tool_call_id' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'tool_call_id'."
    raise MessageFormatError(err, str(data_message))
|
|
|
|
|
def tokenize(
    sample: Union[str, TrainingInstructSample],
    instruct_tokenizer: InstructTokenizerBase,
) -> TokenSample:
    """Tokenize a sample according to its type.

    Args:
        sample: Plain text, or a ``TrainingInstructSample``.
        instruct_tokenizer: Tokenizer used for instruct samples; its
            ``tokenizer`` attribute is used for plain text.

    Returns:
        The ``TokenSample`` from ``tokenize_pretrain`` for a ``str``, or from
        ``tokenize_instruct`` for a ``TrainingInstructSample``.

    Raises:
        ValueError: If ``sample`` is of neither type.
    """
    if isinstance(sample, str):
        return tokenize_pretrain(sample, instruct_tokenizer.tokenizer)

    if isinstance(sample, TrainingInstructSample):
        return tokenize_instruct(sample, instruct_tokenizer)

    raise ValueError(
        f"`sample` has to be either of type `str` or `TrainingInstructSample`, not {type(sample)}."
    )
|
|
|
|
|
def tokenize_pretrain(sample: str, tokenizer: Tokenizer) -> TokenSample:
    """Encode plain text with BOS and EOS and mark every token True.

    Args:
        sample: Text to encode.
        tokenizer: Tokenizer whose ``encode`` is called with ``bos=True`` and
            ``eos=True``.

    Returns:
        A ``TokenSample`` whose mask list is all True and as long as the
        token list.
    """
    token_ids = tokenizer.encode(sample, bos=True, eos=True)
    all_true = [True] * len(token_ids)
    return TokenSample(tokens=token_ids, masks=all_true)
|
|
|
|
|
def tokenize_instruct(
    sample: TrainingInstructSample,
    instruct_tokenizer: InstructTokenizerBase,
) -> TokenSample:
    """Tokenize an instruct sample and build its per-token masks.

    Tokens begin with ``instruct_tokenizer.start()`` (masked False). User and
    tool messages are encoded with False masks. An assistant message gets True
    masks only when its ``weight`` is None or 1 and, if ``sample.only_last``
    is set, when it is the last message of the sample.

    Note:
        The last message is passed through ``maybe_remove_call_id``, which
        rewrites its ``tool_calls`` in place, so ``sample`` may be modified.

    Args:
        sample: Conversation to tokenize.
        instruct_tokenizer: Tokenizer providing ``start`` and the
            ``encode_*_message`` methods.

    Returns:
        A ``TokenSample`` whose ``tokens`` and ``masks`` have equal length.

    Raises:
        ValueError: If a message is not a ``UserMessage``, ``ToolMessage`` or
            ``FinetuningAssistantMessage``.
    """
    # Copy so extending below never touches a list owned by the tokenizer.
    tokens: List[int] = list(instruct_tokenizer.start())
    # One mask entry per start token keeps both lists aligned regardless of
    # how many tokens ``start()`` returns.
    masks: List[bool] = [False] * len(tokens)

    mask_all_but_last = sample.only_last

    # Positions of user messages; -1 when the sample has none.
    user_messages = [
        i for i, msg in enumerate(sample.messages) if isinstance(msg, UserMessage)
    ]
    first_user_idx = user_messages[0] if user_messages else -1
    last_user_idx = user_messages[-1] if user_messages else -1

    for msg_idx, message in enumerate(sample.messages):
        if isinstance(message, UserMessage):
            curr_tokens = instruct_tokenizer.encode_user_message(
                message,
                available_tools=sample.available_tools,
                is_last=msg_idx == last_user_idx,
                is_first=msg_idx == first_user_idx,
                system_prompt=sample.system_prompt,
            )
            curr_masks = [False] * len(curr_tokens)
        elif isinstance(message, ToolMessage):
            curr_tokens = instruct_tokenizer.encode_tool_message(
                message, is_before_last_user_message=msg_idx < last_user_idx
            )
            curr_masks = [False] * len(curr_tokens)
        elif isinstance(message, FinetuningAssistantMessage):
            is_last_message = msg_idx == (len(sample.messages) - 1)

            # Only the final message has its tool-call ids replaced.
            message = maybe_remove_call_id(message, is_last_message=is_last_message)

            curr_tokens = instruct_tokenizer.encode_assistant_message(
                message, is_before_last_user_message=False
            )

            is_weighted = message.weight is None or message.weight == 1
            is_relevant = (not mask_all_but_last) or is_last_message
            if is_weighted and is_relevant:
                curr_masks = [True] * len(curr_tokens)
            else:
                curr_masks = [False] * len(curr_tokens)
        else:
            # Previously such a message re-appended the tokens of the message
            # before it (or hit an UnboundLocalError when it came first).
            raise ValueError(
                f"Cannot tokenize a message of type {type(message)}; expected "
                "`UserMessage`, `ToolMessage` or `FinetuningAssistantMessage`."
            )

        tokens.extend(curr_tokens)
        masks.extend(curr_masks)

    return TokenSample(tokens, masks)
|
|
|
|
|
def maybe_remove_call_id(
    message: FinetuningAssistantMessage, is_last_message: bool
) -> FinetuningAssistantMessage:
    """Rebuild the tool calls of the last message without their original ids.

    When ``message`` has tool calls and ``is_last_message`` is True, each call
    is replaced by a new ``ToolCall`` built from its ``function`` only, so the
    original ``id`` is not carried over. Otherwise the message is untouched.

    Args:
        message: Assistant message to process; modified in place.
        is_last_message: Whether ``message`` is the final message of the sample.

    Returns:
        The same ``message`` object.
    """
    if message.tool_calls is not None and is_last_message:
        stripped_calls = []
        for call in message.tool_calls:
            stripped_calls.append(ToolCall(function=call.function))
        message.tool_calls = stripped_calls

    return message
|
|