import functools
import io
from typing import List, Tuple

import jax
import ml_collections
import numpy as np
import sentencepiece
import supervision as sv
import tensorflow as tf
from PIL import Image

import big_vision.sharding
import big_vision.utils
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns
SEQLEN = 128
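# Singleton wrapper around a big_vision PaliGemma checkpoint. It builds the
# decode function, device mesh and shardings once, and exposes
# predict(image_bytes, caption) -> (bbox, response).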
class PaliGemmaManager:
    _instance = None
    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
def __init__(self, model, params, tokenizer):
self.model = model
self.params = params
self.tokenizer = tokenizer
self.decode_fn = None
self.decode = None
self.mesh = None
self.data_sharding = None
self.params_sharding = None
self.trainable_mask = None
self.initialise_model()
def initialise_model(self):
self.decode_fn = predict_fns.get_all(self.model)['decode']
self.decode = functools.partial(self.decode_fn, devices=jax.devices(), eos_token=self.tokenizer.eos_id())
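        # Mark which params would be trainable (attention layers only) for a
        # potential fine-tune; for pure inference this mask is informational.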
def is_trainable_param(name, param):
if name.startswith("llm/layers/attn/"): return True
if name.startswith("llm/"): return False
if name.startswith("img/"): return False
raise ValueError(f"Unexpected param name {name}")
self.trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, self.params)
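        # Shard the batch along a single "data" axis and the params FSDP-style
        # across the same devices, mirroring the big_vision PaliGemma setup.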
        self.mesh = jax.sharding.Mesh(jax.devices(), ("data",))
        self.data_sharding = jax.sharding.NamedSharding(
            self.mesh, jax.sharding.PartitionSpec("data"))
        self.params_sharding = big_vision.sharding.infer_sharding(
            self.params, strategy=[('.*', 'fsdp(axis="data")')], mesh=self.mesh)
    def preprocess_image(self, image, size=224):
        image = np.asarray(image)
        if image.ndim == 2:  # Replicate a greyscale image across three channels.
            image = np.stack((image,) * 3, axis=-1)
        image = image[..., :3]  # Drop the alpha channel if present.
assert image.shape[-1] == 3
image = tf.constant(image)
image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
return image.numpy() / 127.5 - 1.0
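    # PaliGemma input layout: the prefix (prompt) uses full attention and is
    # excluded from the loss; the optional suffix (target) is causal and
    # contributes to the loss. Sequences are padded/truncated to `seqlen`.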
def preprocess_tokens(self, prefix, suffix=None, seqlen=None):
separator = "\n"
tokens = self.tokenizer.encode(prefix, add_bos=True) + self.tokenizer.encode(separator)
mask_ar = [0] * len(tokens) # 0 to use full attention for prefix.
mask_loss = [0] * len(tokens) # 0 to not use prefix tokens in the loss.
if suffix:
suffix = self.tokenizer.encode(suffix, add_eos=True)
tokens += suffix
mask_ar += [1] * len(suffix) # 1 to use causal attention for suffix.
mask_loss += [1] * len(suffix) # 1 to use suffix tokens in the loss.
        mask_input = [1] * len(tokens)  # 1 if it's a real token, 0 if padding.
if seqlen:
padding = [0] * max(0, seqlen - len(tokens))
tokens = tokens[:seqlen] + padding
mask_ar = mask_ar[:seqlen] + padding
mask_loss = mask_loss[:seqlen] + padding
mask_input = mask_input[:seqlen] + padding
return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))
def postprocess_tokens(self, tokens):
tokens = tokens.tolist() # np.array to list[int]
try: # Remove tokens at and after EOS if any.
eos_pos = tokens.index(self.tokenizer.eos_id())
tokens = tokens[:eos_pos]
except ValueError:
pass
return self.tokenizer.decode(tokens)
    @staticmethod
    def split_and_keep_second_part(s):
        # Helper (currently unused): drop everything before the first newline.
        parts = s.split('\n', 1)
        if len(parts) > 1:
            return parts[1]
        return s
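    # Yields a single preprocessed example; make_predictions consumes this
    # iterator and handles batching and padding.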
def data_iterator(self, image_bytes, caption):
image = Image.open(io.BytesIO(image_bytes))
image = self.preprocess_image(image)
tokens, mask_ar, _, mask_input = self.preprocess_tokens(caption, seqlen=SEQLEN)
yield {
"image": np.asarray(image),
"text": np.asarray(tokens),
"mask_ar": np.asarray(mask_ar),
"mask_input": np.asarray(mask_input),
}
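    # Batches examples from the iterator, padding the final partial batch by
    # repeating the last example (flagged via "_mask" so padded outputs are
    # dropped after decoding).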
def make_predictions(self, data_iterator, *, num_examples=None,
batch_size=4, seqlen=SEQLEN, sampler="greedy"):
outputs = []
while True:
examples = []
try:
for _ in range(batch_size):
examples.append(next(data_iterator))
examples[-1]["_mask"] = np.array(True) # Indicates true example.
except StopIteration:
if len(examples) == 0:
return outputs
while len(examples) % batch_size:
examples.append(dict(examples[-1]))
examples[-1]["_mask"] = np.array(False) # Indicates padding example.
batch = jax.tree.map(lambda *x: np.stack(x), *examples)
batch = big_vision.utils.reshard(batch, self.data_sharding)
tokens = self.decode({"params": self.params}, batch=batch,
max_decode_len=seqlen, sampler=sampler)
# Fetch model predictions to device and detokenize.
tokens, mask = jax.device_get((tokens, batch["_mask"]))
tokens = tokens[mask] # remove padding examples.
responses = [self.postprocess_tokens(t) for t in tokens]
for example, response in zip(examples, responses):
outputs.append((example["image"], response))
if num_examples and len(outputs) >= num_examples:
return outputs
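    # Parses the model's <locXXXX> detection tokens into an [x, y, w, h] box
    # in the original image's pixel coordinates.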
    def process_result_to_bbox(self, image, caption, classes, w, h):
        # Note: the `classes` argument is currently unused; supervision parses
        # the class name directly out of the result string.
        image = ((image + 1) / 2 * 255).astype(np.uint8)  # [-1, 1] -> [0, 255]
        try:
            # PaliGemma emits <locYYYY> tokens on a 1024-bin grid; supervision
            # rescales them to the original resolution via resolution_wh.
            detections = sv.Detections.from_lmm(
                lmm='paligemma',
                result=caption,
                resolution_wh=(w, h))
            x1, y1, x2, y2 = detections.xyxy[0]
            width = x2 - x1
            height = y2 - y1
            output = [x1, y1, width, height]
        except Exception as e:
            print('Detection parsing failed:', e)
            output = [0, 0, 0, 0]
        return output
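    # End-to-end inference: ensure the "detect" task prefix, run one example
    # through the model, and convert the response into a bbox in the original
    # image's coordinates.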
    def predict(self, image: bytes, caption: str) -> Tuple[List[float], str]:
        image_original = Image.open(io.BytesIO(image))
        original_width, original_height = image_original.size
        if "detect" not in caption:
            caption = f"detect {caption}"
        for pred_image, response in self.make_predictions(
                self.data_iterator(image, caption), num_examples=1):
            classes = response.replace("detect ", "")
            output = self.process_result_to_bbox(
                pred_image, response, classes, original_width, original_height)
            return (output, response)
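# Script entry point: instantiate the model and tokenizer, load the fine-tuned
# checkpoint, then run a single detection query on one image.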
INFERENCE_IMAGE = '3_(backup)AdityaBY_img_14.png'
INFERENCE_PROMPT = "A mother takes a picture of her daughter holding a colourful wind spinner in front of the entrance."
TOKENIZER_PATH = '/home/lyka/air/Paligemma/pali-package/pali_open_vocab_annotations_tokenizer.model'
MODEL_PATH = '/home/lyka/air/Paligemma/pali-package/pali_open_vocab_annotations_segmentation.npz'
model_config = ml_collections.FrozenConfigDict({
"llm": {"vocab_size": 257_152},
"img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)
# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)
paligemma_manager = PaliGemmaManager(model, params, tokenizer)
with open(INFERENCE_IMAGE, 'rb') as f:
    image_bytes = f.read()
output, response = paligemma_manager.predict(image_bytes, INFERENCE_PROMPT)
image = Image.open(INFERENCE_IMAGE)
detections = sv.Detections.from_lmm(
    lmm='paligemma',
    result=response,
    resolution_wh=image.size)
coordinates = detections.xyxy[0]  # Take the first detection.
x1, y1, x2, y2 = coordinates
print('x1, y1, x2, y2:', coordinates)
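# Optional visualisation (a minimal sketch, not part of the original script):
# draw the detected box onto the image with supervision's BoxAnnotator and
# save it to 'annotated.png', a hypothetical output path.
annotated = sv.BoxAnnotator().annotate(
    scene=np.array(image.convert("RGB")), detections=detections)
Image.fromarray(annotated).save('annotated.png')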