Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import importlib | |
from typing import Dict, Optional, Sequence, List | |
import transformers | |
from tinychart.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN | |
from tinychart import conversation as conversation_lib | |
from tinychart.arguments import * | |
PREPROCESS_REGISTRY = {} | |
def register_preprocess(name): | |
def register_preprocess_cls(cls): | |
if name in PREPROCESS_REGISTRY: | |
return PREPROCESS_REGISTRY[name] | |
PREPROCESS_REGISTRY[name] = cls | |
return cls | |
return register_preprocess_cls | |
def import_modules(modules_dir, namespace): | |
for file in os.listdir(modules_dir): | |
path = os.path.join(modules_dir, file) | |
if ( | |
not file.startswith("_") | |
and not file.startswith(".") | |
and (file.endswith(".py") or os.path.isdir(path)) | |
): | |
module_name = file[: file.find(".py")] if file.endswith(".py") else file | |
importlib.import_module(namespace + "." + module_name) | |
models_dir = os.path.join(os.path.dirname(__file__), 'preprocess') | |
import_modules(models_dir, "tinychart.data.preprocess") | |
def PreprocessSelect(version): | |
result = PREPROCESS_REGISTRY.get(version, None) | |
if result is None: | |
for name in PREPROCESS_REGISTRY.keys(): | |
if version in name: | |
result = PREPROCESS_REGISTRY[name] | |
break | |
if result is None: | |
result = PREPROCESS_REGISTRY['default'] | |
return result | |
def preprocess_multimodal( | |
sources: Sequence[str], | |
data_args: DataArguments | |
) -> Dict: | |
is_multimodal = data_args.is_multimodal | |
if not is_multimodal: | |
return sources | |
for source in sources: | |
for sentence in source: | |
if DEFAULT_IMAGE_TOKEN in sentence['value']: | |
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() | |
sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] | |
sentence['value'] = sentence['value'].strip() | |
if "mmtag" in conversation_lib.default_conversation.version: | |
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, | |
'<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>') | |
replace_token = DEFAULT_IMAGE_TOKEN | |
if data_args.mm_use_im_start_end: | |
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN | |
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) | |
return sources | |
def preprocess( | |
sources: Sequence[str], | |
tokenizer: transformers.PreTrainedTokenizer, | |
has_image: bool = False | |
) -> Dict: | |
return PreprocessSelect(conversation_lib.default_conversation.version)(sources, tokenizer, has_image) | |