Spaces:
Running
on
Zero
Running
on
Zero
import time | |
from typing import Dict, Optional, Sequence, List | |
import copy | |
import transformers | |
import tokenizers | |
import torch | |
from tinychart.data.process import register_preprocess | |
from tinychart.mm_utils import tokenizer_image_token | |
from tinychart import conversation as conversation_lib | |
from tinychart.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, \ | |
DEFAULT_IM_END_TOKEN | |
from packaging import version | |
# IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') | |
def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Render conversations with the ``conv_phi_v0`` template and build labels.

    Each conversation is turned into a prompt string, tokenized, and a copy
    of the token ids is used as the label tensor. In that copy the leading
    part of every round (everything up to and including the
    ``conv.sep + conv.roles[1] + ": "`` marker, minus one token) is
    overwritten with ``IGNORE_INDEX``.

    Args:
        sources: Iterable of conversations. Each conversation is a sequence
            of dicts with a ``"from"`` key (``"human"`` or ``"gpt"``) and a
            ``"value"`` key holding the message text. A leading message that
            is not from ``"human"`` is dropped.
        tokenizer: Tokenizer used for the prompts. ``model_max_length`` and
            ``pad_token_id`` are read from it.
        has_image: When true, every prompt is tokenized individually with
            ``tokenizer_image_token`` and the results are stacked, so all
            prompts must tokenize to the same length. When false, the whole
            batch is tokenized at once with padding to the longest prompt
            and truncation to ``tokenizer.model_max_length``.

    Returns:
        A dict with ``input_ids`` (the tokenized prompts) and ``labels``
        (the masked copy of ``input_ids``).

    Raises:
        AssertionError: If the messages of a conversation do not alternate
            between the two template roles, or if the template's separator
            style is not ``SeparatorStyle.TWO``.
        KeyError: If a message's ``"from"`` value is neither ``"human"``
            nor ``"gpt"``.
    """
    conv = conversation_lib.conv_phi_v0.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        # The template object is reused, so reset it for every conversation.
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    if has_image:
        # No padding happens on this path: torch.stack needs equal lengths.
        input_ids = torch.stack(
            [
                tokenizer_image_token(prompt, tokenizer, return_tensors='pt')
                for prompt in conversations
            ],
            dim=0,
        )
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    # Iterating `targets` yields its rows, so the in-place writes below
    # modify `targets` itself.
    for conversation, target in zip(conversations, targets):
        # Number of non-padding tokens in this sample.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        # With a start offset of 0 the slice below is empty and masks
        # nothing; it only has an effect if the offset is made non-zero.
        cur_len = 0
        target[:cur_len] = IGNORE_INDEX
        for rou in rounds:
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                # Not exactly one `sep` marker in this round: stop masking
                # here; the length check below decides what happens next.
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids)

            # NOTE(review): presumably compensates for the trailing space of
            # `sep` tokenizing differently when parts[0] is encoded on its
            # own than inside the full prompt -- confirm for this tokenizer.
            instruction_len -= 1

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            # NOTE(review): `round_len` is measured on text from which
            # `conv.sep2` was removed by the split above; verify that the
            # running total still lines up with `total_len`.
            cur_len += round_len

        # Only samples shorter than the model limit are checked. If the
        # accumulated round lengths disagree with the real token count, the
        # mask boundaries cannot be trusted, so the whole sample is excluded
        # from the loss.
        if cur_len < tokenizer.model_max_length and cur_len != total_len:
            print(
                f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                f" (ignored)"
            )
            print("number of rounds: ", len(rounds) - 1)
            print("rounds: ", rounds[:-1])
            print("conversation: ", conversations)
            print(target)
            print(input_ids)
            target[:] = IGNORE_INDEX

    return dict(
        input_ids=input_ids,
        labels=targets,
    )