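"""'v1' preprocessor for TinyChart conversation data.

Renders LLaVA-style conversation records with the conv_phi_v0 template,
tokenizes them (routing <image> placeholders through tokenizer_image_token),
and produces training labels in which everything except the assistant
responses is masked to IGNORE_INDEX.
"""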
from typing import Dict

import torch
import transformers

from tinychart import conversation as conversation_lib
from tinychart.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN
from tinychart.data.process import register_preprocess
from tinychart.mm_utils import tokenizer_image_token

@register_preprocess('v1')
def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    conv = conversation_lib.conv_phi_v0.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first message if it is not from the human role
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations. Prompts that contain <image> placeholders must go
    # through tokenizer_image_token, which replaces them with IMAGE_TOKEN_INDEX.
    if has_image:
        input_ids = torch.stack(
            [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations],
            dim=0,
        )
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

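    # Build the loss mask. With SeparatorStyle.TWO the rendered prompt is a
    # sequence of rounds, each terminated by conv.sep2, of the form
    # "<user tag>: <question> <assistant tag>: <answer>". Splitting on conv.sep2
    # recovers the rounds; within each round, the tokens up to and including the
    # assistant tag are set to IGNORE_INDEX so the loss only covers responses.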
    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 0  # the phi tokenizer prepends no BOS token, so nothing to skip
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # Token counts for the whole round and for its instruction part
            # (everything up to and including the assistant tag).
            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids)
            # Off-by-one correction for how the separator tokenizes here.
            instruction_len -= 1

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )
                print("number of rounds: ", len(rounds) - 1)
                print("rounds: ", rounds[:-1])
                print("conversation: ", conversation)
                # Mask the whole sample so a miscounted conversation cannot
                # contribute corrupted supervision.
                target[:] = IGNORE_INDEX

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
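

if __name__ == "__main__":
    # Minimal smoke test: an illustrative sketch only. "microsoft/phi-2" is an
    # assumed stand-in for the tokenizer this repo actually trains with, and
    # the pad-token fallback below exists only for this demo.
    tok = transformers.AutoTokenizer.from_pretrained("microsoft/phi-2")
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    sample = [[
        {"from": "human", "value": DEFAULT_IMAGE_TOKEN + "\nWhat is the highest bar?"},
        {"from": "gpt", "value": "The highest bar is 42."},
    ]]
    batch = preprocess_v1(sample, tok, has_image=True)
    print(batch["input_ids"].shape)
    print(int((batch["labels"] != IGNORE_INDEX).sum()), "supervised tokens")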