import os
import shutil
import sys

import torch
from transformers import TextStreamer

sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "Evaluation"))

from llava.constants import IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import get_model_name_from_path, KeywordsStoppingCriteria, tokenizer_image_token
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init

cur_dir = os.path.dirname(os.path.abspath(__file__))

title_markdown = ("""

VLM-RLAIF: Tuning Large Multimodal Models for Videos using Reinforcement Learning from AI Feedback (ACL 2024 Oral)

If you like our project, please give us a star ✨ on GitHub for the latest updates.
""") block_css = """ #buttons button { min-width: min(120px,100%); } """ tos_markdown = ("""""") learn_more_markdown = (""" ### License The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA. """) class Chat: def __init__(self, model_path, conv_mode, model_base=None, load_8bit=False, load_4bit=False, device='cuda', cache_dir=None): disable_torch_init() model_name = get_model_name_from_path(model_path) is_rlhf_checkpoint = 'rlhf' in model_path.lower() print("MODEL_PATH", model_path) print("RLHF Checkpoint: ", is_rlhf_checkpoint) if not model_base or model_base == "none": model_base = None if is_rlhf_checkpoint: model_name = model_path print("Config?", os.path.exists(os.path.join(model_path, "config.json"))) if not os.path.exists(os.path.join(model_path, "config.json")): print("Copying") shutil.copy(os.path.join(model_base, "config.json"), os.path.join(model_path, "config.json")) # Copy SFT model's config -> to RLHF folder print("Listed", os.listdir(model_path)) print("Copying done") self.tokenizer, self.model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name, False, False, device=device) self.image_processor = image_processor self.conv_mode = conv_mode self.conv = conv_templates[conv_mode].copy() self.device = self.model.device print(self.model) def get_prompt(self, qs, state): state.append_message(state.roles[0], qs) state.append_message(state.roles[1], None) return state def _get_latest_prompt(self, state): new_state = state.copy() new_state.messages = state.messages[-2:] return new_state @torch.inference_mode() # def generate(self, images_tensor: list, prompt: str, first_run: bool, state): def generate(self, images_tensor: torch.Tensor, prompt: str, first_run: bool, state): tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor state = self.get_prompt(prompt, state) # prompt = state.get_prompt() latest_state = self._get_latest_prompt(state) prompt = latest_state.get_prompt() input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device) temperature = 0.2 max_new_tokens = 1024 stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2 keywords = [stop_str] stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) print(prompt, input_ids.shape, images_tensor.shape) # print(images_tensor) with torch.inference_mode(): output_ids = model.generate( input_ids, images=images_tensor, do_sample=True, temperature=temperature, max_new_tokens=max_new_tokens, streamer=streamer, use_cache=True, stopping_criteria=[stopping_criteria]) input_token_len = input_ids.shape[1] n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() if n_diff_input_output > 0: print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] outputs = outputs.strip() outputs = outputs.replace("QA_GT_caption_based_noisy", "") if outputs.endswith(stop_str): outputs = outputs[:-len(stop_str)] outputs = outputs.strip() print('response', outputs) return outputs, state