# Source: TinyChart-3B / tinychart/eval/run_tiny_chart.py
# Uploaded by: xzl12306
# Commit: d6bc023 ("first commit")
import argparse
import torch
from tinychart.constants import (
IMAGE_TOKEN_INDEX,
DEFAULT_IMAGE_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
IMAGE_PLACEHOLDER,
)
from tinychart.conversation import conv_templates, SeparatorStyle
from tinychart.model.builder import load_pretrained_model
from tinychart.utils import disable_torch_init
from tinychart.mm_utils import (
process_images,
tokenizer_image_token,
get_model_name_from_path,
KeywordsStoppingCriteria,
)
from PIL import Image
import requests
from PIL import Image
from io import BytesIO
import re
def image_parser(args):
    """Split the combined image argument into individual image paths/URLs.

    ``args`` must expose two string attributes (e.g. an ``argparse.Namespace``):
    ``image_file``, the joined list of images, and ``sep``, the delimiter
    between them.

    Returns:
        List of the pieces of ``args.image_file`` split on ``args.sep``.
    """
    separator = args.sep
    joined = args.image_file
    return joined.split(separator)
def load_image(image_file, timeout=30):
    """Open a local path or an http(s) URL as an RGB ``PIL.Image``.

    Args:
        image_file: Local filesystem path, or a URL starting with ``http://``
            or ``https://``.
        timeout: Seconds passed to ``requests.get`` as its timeout when
            ``image_file`` is a URL. Ignored for local files.

    Returns:
        A new image converted to RGB mode.

    Raises:
        requests.HTTPError: If the URL answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    # Match the full scheme. A bare "http" prefix would also send local files
    # such as "http_chart.png" to requests.
    if image_file.startswith(("http://", "https://")):
        # The timeout keeps an unresponsive server from hanging the script.
        # raise_for_status surfaces HTTP errors instead of handing an error
        # page to PIL.
        response = requests.get(image_file, timeout=timeout)
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as img:
            return img.convert("RGB")
    # convert() loads the pixels into a new image, so the source file handle
    # can be closed immediately.
    with Image.open(image_file) as img:
        return img.convert("RGB")
def load_images(image_files):
    """Load each path/URL in ``image_files`` with ``load_image``.

    Returns:
        A list of the loaded images, in the same order as ``image_files``.
    """
    return [load_image(path) for path in image_files]
def inference_model(
    image_files,
    query,
    model,
    tokenizer,
    image_processor,
    context_len,
    conv_mode,
    temperature=0,
    max_new_tokens=100,
    top_p=None,
    num_beams=None,
):
    """Answer ``query`` about the given image(s) and return the generated text.

    Args:
        image_files: Iterable of local paths or URLs, loaded via ``load_images``.
        query: User question. If it contains ``IMAGE_PLACEHOLDER``, every
            occurrence is replaced by the image token(s). Otherwise the image
            token(s) are prepended on their own line.
        model: Loaded multimodal model. ``model.config.mm_use_im_start_end``
            selects whether the image token is wrapped in start/end markers.
            ``model.device`` is where all inputs are moved.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        image_processor: Processor passed to ``process_images``.
        context_len: Not used by this function; kept for call-site
            compatibility.
        conv_mode: Key into ``conv_templates`` selecting the prompt template.
        temperature: Sampling temperature. ``<= 0`` selects greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Optional nucleus-sampling threshold. It is forwarded only when
            sampling and not ``None``.
        num_beams: Optional beam count. It is forwarded only when not
            ``None``; otherwise the model's generation default applies.

    Returns:
        The decoded answer with surrounding whitespace and a trailing stop
        string removed. The answer is also printed to stdout.

    Raises:
        KeyError: If ``conv_mode`` is not a key of ``conv_templates``.
    """
    if conv_mode not in conv_templates:
        raise KeyError(
            f"Unknown conv_mode {conv_mode!r}; expected one of {sorted(conv_templates)}"
        )

    # Choose the image token form the model was trained with.
    if model.config.mm_use_im_start_end:
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_token = DEFAULT_IMAGE_TOKEN
    if IMAGE_PLACEHOLDER in query:
        # Plain substring replacement. re.sub would treat the placeholder as a
        # regex and interpret backslash escapes in the replacement.
        qs = query.replace(IMAGE_PLACEHOLDER, image_token)
    else:
        qs = image_token + "\n" + query

    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)  # empty assistant turn to be generated
    prompt = conv.get_prompt()

    images = load_images(image_files)
    # NOTE(review): float16 is hard-coded; confirm it matches the dtype the
    # model was loaded in.
    images_tensor = process_images(images, image_processor, model.config).to(
        model.device, dtype=torch.float16
    )
    # Put the token ids on the same device as the model and image tensor.
    # The previous .cuda() always targeted the current CUDA device, which can
    # differ from model.device.
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    do_sample = temperature > 0
    gen_kwargs = {}
    if do_sample:
        # Sampling-only knobs. Greedy decoding ignores them and transformers
        # merely warns when they are passed.
        gen_kwargs["temperature"] = temperature
        if top_p is not None:
            gen_kwargs["top_p"] = top_p
    if num_beams is not None:
        gen_kwargs["num_beams"] = num_beams

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images_tensor,
            do_sample=do_sample,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            **gen_kwargs,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    # Guard the empty-stop-string case: outputs[:-0] would erase the answer.
    if stop_str and outputs.endswith(stop_str):
        outputs = outputs[: -len(stop_str)]
    outputs = outputs.strip()
    print(outputs)
    return outputs
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-file", type=str, required=True)
    parser.add_argument("--query", type=str, required=True)
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--sep", type=str, default=",")
    parser.add_argument("--temperature", type=float, default=0.2)
    # NOTE: --top_p and --num_beams are parsed for CLI compatibility but are
    # not forwarded to inference_model here.
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=512)
    args = parser.parse_args()

    # Fail fast on a missing or unknown template, before the expensive model
    # load. inference_model would otherwise hit conv_templates[None].
    if args.conv_mode not in conv_templates:
        parser.error(
            "--conv-mode must be one of: " + ", ".join(sorted(conv_templates))
        )

    disable_torch_init()
    model_name = get_model_name_from_path(args.model_path)
    # NOTE(review): assumes the LLaVA-style builder signature
    # (model_path, model_base, model_name) returning
    # (tokenizer, model, image_processor, context_len). Confirm against
    # tinychart.model.builder.
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        args.model_path, args.model_base, model_name
    )

    # The original call inference_model(args) passed one argument to a
    # seven-parameter function and always raised TypeError.
    inference_model(
        image_parser(args),
        args.query,
        model,
        tokenizer,
        image_processor,
        context_len,
        args.conv_mode,
        temperature=args.temperature,
        max_new_tokens=args.max_new_tokens,
    )