VLRM

This repository contains the weights of BLIP-2 OPT-2.7B model fine-tuned by reinforcement learning method introduced in the paper VLRM: Vision-Language Models act as Reward Models for Image Captioning.

The RL-tuned model is able to generate longer and more comprehensive descriptions with zero computational overhead compared to the original model.

You can find other details in the GitHub Repository (to be done).

Running the model

Option 1

Load the whole model from this repo
import torch
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

processor = Blip2Processor.from_pretrained("sashakunitsyn/vlrm-blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("sashakunitsyn/vlrm-blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")

img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)

out = model.generate(**inputs, max_new_tokens=60)
processor.decode(out[0], skip_special_tokens=True).strip()
>>> 'a woman in a plaid shirt shaking hands with a yellow labrador retriever sitting on the ground at sunset on a beach in florida'

Option 2

Since the fine-tuned layers take small part of the whole model, you can first load the original model, then load the RL-tuned weights.

Step 1. Load the original model
import torch
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")

img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)

out = model.generate(**inputs, max_new_tokens=60)
processor.decode(out[0], skip_special_tokens=True).strip()
>>> 'a woman sitting on the beach with a dog'
Step 2. Load the RL-tuned weights Available checkpoints:
  • vlrm-blip2-opt-2.7b.pt (VLRM in the paper)
  • vlrm-rs-blip2-opt-2.7b.pt (VLRM-RS in the paper)
from huggingface_hub import hf_hub_download
finetuned_weights_state_dict = torch.load(hf_hub_download(repo_id="sashakunitsyn/vlrm-blip2-opt-2.7b", filename="vlrm-blip2-opt-2.7b.pt"))
model.load_state_dict(finetuned_weights_state_dict, strict=False)

out = model.generate(**inputs, max_new_tokens=60)
processor.decode(out[0], skip_special_tokens=True).strip()
>>> 'a woman in a plaid shirt shaking hands with a yellow labrador retriever sitting on the ground at sunset on a beach in florida'
Downloads last month
166
Safetensors
Model size
3.74B params
Tensor type
FP16
ยท
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Model tree for sashakunitsyn/vlrm-blip2-opt-2.7b

Finetuned
(2)
this model

Spaces using sashakunitsyn/vlrm-blip2-opt-2.7b 2