|
|
|
import os |
|
import io |
|
import jax |
|
import base64 |
|
import warnings |
|
import functools |
|
import numpy as np |
|
import sentencepiece |
|
import ml_collections |
|
from PIL import Image |
|
import big_vision.utils |
|
import tensorflow as tf |
|
import supervision as sv |
|
import big_vision.sharding |
|
from typing import Tuple, List, Optional |
|
from big_vision.models.proj.paligemma import paligemma |
|
from big_vision.trainers.proj.paligemma import predict_fns |
|
|
|
|
|
|
|
# Default token sequence length: prompts are padded/truncated to this length
# and it is also the default `max_decode_len` handed to the decode function.
SEQLEN = 128


class PaliGemmaManager:
    """Singleton wrapper around a PaliGemma model used for "detect" prompts.

    The class keeps a single shared instance (see ``__new__``). Note that
    ``__init__`` still runs on every ``PaliGemmaManager(...)`` call, so
    constructing it again re-assigns the model/params/tokenizer on the shared
    instance and re-runs ``initialise_model``.

    Attributes set by ``__init__`` / ``initialise_model``:
        model, params, tokenizer: the objects passed to the constructor.
        decode_fn: the ``'decode'`` entry of ``predict_fns.get_all(model)``.
        decode: ``decode_fn`` with ``devices`` and ``eos_token`` pre-bound.
        mesh, data_sharding, params_sharding: JAX sharding objects built over
            ``jax.devices()`` with a single ``"data"`` axis.
        trainable_mask: per-parameter booleans derived from parameter names.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, model, params, tokenizer):
        """Store the model pieces and build the decode/sharding helpers.

        Args:
            model: model object passed to ``predict_fns.get_all``.
            params: parameter tree passed to the decode function.
            tokenizer: object exposing ``encode``, ``decode`` and ``eos_id``
                (a SentencePiece processor in the script below this class).
        """
        self.model = model
        self.params = params
        self.tokenizer = tokenizer
        self.decode_fn = None
        self.decode = None
        self.mesh = None
        self.data_sharding = None
        self.params_sharding = None
        self.trainable_mask = None

        self.initialise_model()

    def initialise_model(self):
        """Build the decode callable, the trainable mask and the shardings."""
        self.decode_fn = predict_fns.get_all(self.model)['decode']
        self.decode = functools.partial(
            self.decode_fn,
            devices=jax.devices(),
            eos_token=self.tokenizer.eos_id(),
        )

        def is_trainable_param(name, param):
            # `param` is unused; the two-argument signature is what
            # `tree_map_with_names` is called with here.
            if name.startswith("llm/layers/attn/"):
                return True
            if name.startswith("llm/"):
                return False
            if name.startswith("img/"):
                return False
            raise ValueError(f"Unexpected param name {name}")

        self.trainable_mask = big_vision.utils.tree_map_with_names(
            is_trainable_param, self.params)

        # Axis names must be a sequence; ("data") is just the string "data",
        # so spell the one-element tuple explicitly.
        self.mesh = jax.sharding.Mesh(jax.devices(), ("data",))

        self.data_sharding = jax.sharding.NamedSharding(
            self.mesh, jax.sharding.PartitionSpec("data"))

        # NOTE(review): params_sharding is computed but `self.params` is never
        # re-sharded with it inside this class -- confirm whether that is
        # intentional.
        self.params_sharding = big_vision.sharding.infer_sharding(
            self.params, strategy=[('.*', 'fsdp(axis="data")')], mesh=self.mesh)

    def preprocess_image(self, image, size=224):
        """Convert an image to a ``(size, size, 3)`` float array.

        Args:
            image: anything ``np.asarray`` accepts (e.g. a PIL image).
                2-D input is replicated to three channels; channels beyond
                the third (such as alpha) are dropped.
            size: side length of the square output.

        Returns:
            The bilinearly resized image as a NumPy array computed as
            ``pixels / 127.5 - 1.0`` (i.e. [-1, 1] for 0..255 input).

        Raises:
            ValueError: if the image does not end up with three channels.
        """
        image = np.asarray(image)
        if image.ndim == 2:
            image = np.stack((image,) * 3, axis=-1)

        image = image[..., :3]
        # Explicit raise instead of `assert`, which is stripped under -O.
        if image.shape[-1] != 3:
            raise ValueError(
                f"Expected an image with 3 channels, got shape {image.shape}")

        image = tf.constant(image)
        image = tf.image.resize(
            image, (size, size), method='bilinear', antialias=True)
        return image.numpy() / 127.5 - 1.0

    def preprocess_tokens(self, prefix, suffix=None, seqlen=None):
        """Tokenize a prompt (and optional target) and build its masks.

        The token sequence is ``BOS + prefix + "\\n"`` followed, when
        ``suffix`` is given, by ``suffix + EOS``.

        Args:
            prefix: prompt text.
            suffix: optional target text.
            seqlen: when truthy, every sequence is truncated or zero-padded
                to exactly this length.

        Returns:
            A tuple ``(tokens, mask_ar, mask_loss, mask_input)`` of 1-D NumPy
            arrays. ``mask_ar`` and ``mask_loss`` are 0 over the prefix and 1
            over the suffix; ``mask_input`` is 1 over real tokens and 0 over
            padding.
        """
        separator = "\n"
        tokens = (self.tokenizer.encode(prefix, add_bos=True)
                  + self.tokenizer.encode(separator))
        mask_ar = [0] * len(tokens)
        mask_loss = [0] * len(tokens)

        if suffix:
            suffix_tokens = self.tokenizer.encode(suffix, add_eos=True)
            tokens += suffix_tokens
            mask_ar += [1] * len(suffix_tokens)
            mask_loss += [1] * len(suffix_tokens)

        mask_input = [1] * len(tokens)
        if seqlen:
            padding = [0] * max(0, seqlen - len(tokens))
            tokens = tokens[:seqlen] + padding
            mask_ar = mask_ar[:seqlen] + padding
            mask_loss = mask_loss[:seqlen] + padding
            mask_input = mask_input[:seqlen] + padding

        # Convert each sequence as a whole. `jax.tree.map(np.array, ...)`
        # treats the Python lists as pytree nodes and therefore yields lists
        # of 0-d arrays rather than one array per sequence.
        return tuple(
            np.asarray(seq) for seq in (tokens, mask_ar, mask_loss, mask_input))

    def postprocess_tokens(self, tokens):
        """Decode generated token ids to text, stopping at the first EOS.

        Args:
            tokens: array-like with a ``tolist()`` method holding token ids.

        Returns:
            The decoded string for the ids before the first EOS id, or for
            all ids when no EOS is present.
        """
        tokens = tokens.tolist()
        try:
            eos_pos = tokens.index(self.tokenizer.eos_id())
        except ValueError:
            # No EOS generated: decode everything.
            pass
        else:
            tokens = tokens[:eos_pos]
        return self.tokenizer.decode(tokens)

    @staticmethod
    def split_and_keep_second_part(s):
        """Return the text after the first newline, or ``s`` if there is none.

        Declared as a static method: the function takes no ``self``, so
        without the decorator it could not be called on an instance.
        """
        parts = s.split('\n', 1)
        if len(parts) > 1:
            return parts[1]
        return s

    def data_iterator(self, image_bytes, caption):
        """Yield a single model-ready example for one image/prompt pair.

        Args:
            image_bytes: encoded image file contents.
            caption: prompt text, tokenized with ``seqlen=SEQLEN``.

        Yields:
            A dict with the keys ``"image"``, ``"text"``, ``"mask_ar"`` and
            ``"mask_input"`` holding NumPy arrays.
        """
        # The pixel data is read inside `preprocess_image`, so the PIL handle
        # can be released as soon as that call returns.
        with Image.open(io.BytesIO(image_bytes)) as pil_image:
            image = self.preprocess_image(pil_image)
        tokens, mask_ar, _, mask_input = self.preprocess_tokens(
            caption, seqlen=SEQLEN)

        yield {
            "image": np.asarray(image),
            "text": np.asarray(tokens),
            "mask_ar": np.asarray(mask_ar),
            "mask_input": np.asarray(mask_input),
        }

    def make_predictions(self, data_iterator, *, num_examples=None,
                         batch_size=4, seqlen=SEQLEN, sampler="greedy"):
        """Decode responses for the examples produced by ``data_iterator``.

        Args:
            data_iterator: iterator of example dicts as produced by
                ``data_iterator`` above.
            num_examples: stop once at least this many results are collected;
                a falsy value means "consume the whole iterator".
            batch_size: number of examples decoded per call.
            seqlen: forwarded to the decode function as ``max_decode_len``.
            sampler: forwarded to the decode function.

        Returns:
            A list of ``(image, response)`` tuples, where ``image`` is the
            example's preprocessed image and ``response`` the decoded text.
        """
        outputs = []
        while True:
            examples = []
            try:
                for _ in range(batch_size):
                    examples.append(next(data_iterator))
                    examples[-1]["_mask"] = np.array(True)
            except StopIteration:
                if len(examples) == 0:
                    return outputs

            # Fill a short final batch with copies of the last example, marked
            # with `_mask=False` so their outputs are dropped below.
            while len(examples) % batch_size:
                examples.append(dict(examples[-1]))
                examples[-1]["_mask"] = np.array(False)

            # Stack the per-example dicts into one batched dict and decode.
            batch = jax.tree.map(lambda *x: np.stack(x), *examples)
            batch = big_vision.utils.reshard(batch, self.data_sharding)
            tokens = self.decode({"params": self.params}, batch=batch,
                                 max_decode_len=seqlen, sampler=sampler)

            # Bring results back to the host and keep only the real examples.
            tokens, mask = jax.device_get((tokens, batch["_mask"]))
            tokens = tokens[mask]
            responses = [self.postprocess_tokens(t) for t in tokens]

            # Real examples precede the padding copies, so zipping against the
            # (shorter) list of responses pairs them up correctly.
            for example, response in zip(examples, responses):
                outputs.append((example["image"], response))
                if num_examples and len(outputs) >= num_examples:
                    return outputs

    def process_result_to_bbox(self, image, caption, classes, w, h):
        """Extract the first detected box from a model response.

        Args:
            image: unused; kept for interface compatibility.
            caption: model response text parsed by
                ``sv.Detections.from_lmm``.
            classes: currently unused (see the note in the body).
            w: width used for ``resolution_wh``.
            h: height used for ``resolution_wh``.

        Returns:
            ``[x, y, width, height]`` of the first detection, or
            ``[0, 0, 0, 0]`` if parsing fails or nothing is detected.
        """
        try:
            # NOTE(review): the response text itself is passed as `classes`
            # and the `classes` argument of this method is ignored. Kept as-is
            # to preserve behaviour -- confirm which value is intended.
            detections = sv.Detections.from_lmm(
                lmm='paligemma',
                result=caption,
                resolution_wh=(w, h),
                classes=caption)

            # Indexing [0] raises when there are no detections; that case is
            # handled by the fallback below.
            x1, y1, x2, y2 = list(detections.xyxy[0])[:4]
            output = [x1, y1, x2 - x1, y2 - y1]
        except Exception as e:
            # Best-effort: report the problem and fall back to an empty box.
            print('Error detection')
            print(e)
            output = [0, 0, 0, 0]

        return output

    def predict(self, image: bytes, caption: str) -> Tuple[List[float], str]:
        """Run a detection prompt on one image.

        Args:
            image: encoded image file contents.
            caption: prompt text; ``"detect "`` is prepended unless the text
                already contains the substring ``"detect"`` anywhere.

        Returns:
            ``(bbox, response)`` where ``bbox`` is ``[x, y, width, height]``
            scaled to the original image size and ``response`` is the decoded
            model output. If no prediction is produced the result is
            ``([0, 0, 0, 0], "")``.
        """
        with Image.open(io.BytesIO(image)) as original:
            original_width, original_height = original.size

        if "detect" not in caption:
            caption = f"detect {caption}"

        # Defaults so an empty prediction list cannot leave these unbound.
        output: List[float] = [0, 0, 0, 0]
        response = ""

        predictions = self.make_predictions(
            self.data_iterator(image, caption), num_examples=1)
        for processed_image, response in predictions:
            classes = response.replace("detect ", "")
            output = self.process_result_to_bbox(
                processed_image, response, classes,
                original_width, original_height)

        return (output, response)
|
|
|
|
|
|
|
|
|
|
|
|
|
# Demo inputs: the image file to run on and the text prompt describing the
# scene to detect.
INFERENCE_IMAGE = '3_(backup)AdityaBY_img_14.png'
INFERENCE_PROMPT = "A mother takes a picture of her daughter holding a colourful wind spinner in front of the entrance."

# Locations of the SentencePiece tokenizer model and the model checkpoint.
TOKENIZER_PATH = '/home/lyka/air/Paligemma/pali-package/pali_open_vocab_annotations_tokenizer.model'
MODEL_PATH = '/home/lyka/air/Paligemma/pali-package/pali_open_vocab_annotations_segmentation.npz'

# Model definition; the same config is passed to both the model constructor
# and the checkpoint loader.
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# NOTE: everything below runs at import time (checkpoint load + inference).
params = paligemma.load(None, MODEL_PATH, model_config)
paligemma_manager = PaliGemmaManager(model, params, tokenizer)

with open(INFERENCE_IMAGE, 'rb') as f:
    image_bytes = f.read()

# `output` is the [x, y, width, height] box; `response` is the raw model text.
output, response = paligemma_manager.predict(image_bytes,
                                             INFERENCE_PROMPT)

# Only the image size is needed here, so release the file handle right away.
with Image.open(INFERENCE_IMAGE) as image:
    resolution_wh = image.size

# Re-parse the raw response to obtain corner coordinates (x1, y1, x2, y2)
# scaled to the original image size.
detections = sv.Detections.from_lmm(
    lmm='paligemma',
    result=response,
    resolution_wh=resolution_wh,
    classes=response)

# Indexing `xyxy[0]` on an empty result would raise IndexError, so handle the
# "nothing detected" case explicitly.
if len(detections.xyxy) == 0:
    coordinates = None
    print('No bounding box found in response:', response)
else:
    coordinates = detections.xyxy[0]
    x1, y1, x2, y2 = coordinates
    print('x1,y1,x2,y2:', coordinates)