import torch
from transformers import TextStreamer
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "Evaluation"))
from llava.constants import IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import get_model_name_from_path, KeywordsStoppingCriteria, tokenizer_image_token
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
import shutil
cur_dir = os.path.dirname(os.path.abspath(__file__))
title_markdown = ("""
VLM-RLAIF: Tuning Large Multimodal Models for Videos using Reinforcement Learning from AI Feedback (ACL 2024 Oral)
If you like our project, please give us a star ✨ on GitHub for the latest updates.
""")
block_css = """
#buttons button {
    min-width: min(120px, 100%);
}
"""
tos_markdown = ("""""")
learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA.
""")
class Chat:
    def __init__(self, model_path, conv_mode, model_base=None, load_8bit=False, load_4bit=False, device='cuda', cache_dir=None):
        disable_torch_init()
        model_name = get_model_name_from_path(model_path)
        is_rlhf_checkpoint = 'rlhf' in model_path.lower()
        print("MODEL_PATH", model_path)
        print("RLHF checkpoint:", is_rlhf_checkpoint)
        if not model_base or model_base == "none":
            model_base = None
        if is_rlhf_checkpoint:
            model_name = model_path
            # RLHF checkpoints may ship without a config.json; copy the SFT base model's config into the RLHF folder.
            if not os.path.exists(os.path.join(model_path, "config.json")):
                print("config.json not found in RLHF checkpoint, copying it from the base model")
                shutil.copy(os.path.join(model_base, "config.json"), os.path.join(model_path, "config.json"))
                print("Copying done:", os.listdir(model_path))
        self.tokenizer, self.model, image_processor, context_len = load_pretrained_model(
            model_path, model_base, model_name, load_8bit, load_4bit, device=device)
        self.image_processor = image_processor
        self.conv_mode = conv_mode
        self.conv = conv_templates[conv_mode].copy()
        self.device = self.model.device
        print(self.model)
    def get_prompt(self, qs, state):
        state.append_message(state.roles[0], qs)
        state.append_message(state.roles[1], None)
        return state

    def _get_latest_prompt(self, state):
        new_state = state.copy()
        new_state.messages = state.messages[-2:]
        return new_state

    @torch.inference_mode()
    def generate(self, images_tensor: torch.Tensor, prompt: str, first_run: bool, state):
        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor

        state = self.get_prompt(prompt, state)
        # Only the latest user/assistant turn is used to build the model prompt.
        latest_state = self._get_latest_prompt(state)
        prompt = latest_state.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)

        temperature = 0.2
        max_new_tokens = 1024

        stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        print(prompt, input_ids.shape, images_tensor.shape)

        output_ids = model.generate(
            input_ids,
            images=images_tensor,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            use_cache=True,
            stopping_criteria=[stopping_criteria])

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        outputs = outputs.replace("QA_GT_caption_based_noisy", "")
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
        print('response', outputs)
        return outputs, state
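

# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal sketch of how this Chat class might be driven outside the Gradio
# demo. The checkpoint paths, the conversation template name, the frame files,
# and the float16 preprocessing below are assumptions, not values shipped with
# this file; adapt them to your own setup.
if __name__ == "__main__":
    from PIL import Image

    chat = Chat(
        model_path="checkpoints/vlm-rlaif",   # hypothetical RLHF checkpoint dir
        model_base="checkpoints/vlm-sft",     # hypothetical SFT base (supplies config.json)
        conv_mode="llava_v1",                 # assumed conversation template name
    )

    # Assumed preprocessing: run sampled video frames through the model's image
    # processor and move them to the model device in half precision.
    frames = [Image.open(p) for p in ["frame_0.jpg", "frame_1.jpg"]]  # hypothetical frame files
    images_tensor = chat.image_processor.preprocess(frames, return_tensors="pt")["pixel_values"]
    images_tensor = images_tensor.to(chat.device, dtype=torch.float16)

    state = chat.conv.copy()
    answer, state = chat.generate(images_tensor, "<image>\nDescribe the video.", first_run=True, state=state)
    print(answer)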