import os
import io
import jax
import base64
import warnings
import functools
import numpy as np
import sentencepiece
import ml_collections
from PIL import Image
import big_vision.utils
import tensorflow as tf
import supervision as sv
import big_vision.sharding
from typing import Tuple, List, Optional
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns

SEQLEN = 128


class PaliGemmaManager:
    _instance = None

    def __new__(cls, *args, **kwargs):
        # Singleton: reuse the same manager instance across calls.
        if not cls._instance:
            cls._instance = super(PaliGemmaManager, cls).__new__(cls)
        return cls._instance

    def __init__(self, model, params, tokenizer):
        self.model = model
        self.params = params
        self.tokenizer = tokenizer
        self.decode_fn = None
        self.decode = None
        self.mesh = None
        self.data_sharding = None
        self.params_sharding = None
        self.trainable_mask = None
        self.initialise_model()

    def initialise_model(self):
        self.decode_fn = predict_fns.get_all(self.model)['decode']
        self.decode = functools.partial(self.decode_fn, devices=jax.devices(),
                                        eos_token=self.tokenizer.eos_id())

        def is_trainable_param(name, param):
            if name.startswith("llm/layers/attn/"):
                return True
            if name.startswith("llm/"):
                return False
            if name.startswith("img/"):
                return False
            raise ValueError(f"Unexpected param name {name}")

        self.trainable_mask = big_vision.utils.tree_map_with_names(
            is_trainable_param, self.params)

        # Shard data across all devices; shard params FSDP-style along the same axis.
        self.mesh = jax.sharding.Mesh(jax.devices(), ("data"))
        self.data_sharding = jax.sharding.NamedSharding(
            self.mesh, jax.sharding.PartitionSpec("data"))
        self.params_sharding = big_vision.sharding.infer_sharding(
            self.params, strategy=[('.*', 'fsdp(axis="data")')], mesh=self.mesh)

    def preprocess_image(self, image, size=224):
        image = np.asarray(image)
        if image.ndim == 2:  # Convert a 2-D greyscale image to 3 channels.
            image = np.stack((image,) * 3, axis=-1)
        image = image[..., :3]  # Remove alpha layer.
        assert image.shape[-1] == 3
        image = tf.constant(image)
        image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
        return image.numpy() / 127.5 - 1.0  # Rescale to [-1, 1].

    def preprocess_tokens(self, prefix, suffix=None, seqlen=None):
        separator = "\n"
        tokens = self.tokenizer.encode(prefix, add_bos=True) + self.tokenizer.encode(separator)
        mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.
        mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.

        if suffix:
            suffix = self.tokenizer.encode(suffix, add_eos=True)
            tokens += suffix
            mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.
            mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.

        mask_input = [1] * len(tokens)  # 1 if it's a token, 0 if padding.
        if seqlen:
            padding = [0] * max(0, seqlen - len(tokens))
            tokens = tokens[:seqlen] + padding
            mask_ar = mask_ar[:seqlen] + padding
            mask_loss = mask_loss[:seqlen] + padding
            mask_input = mask_input[:seqlen] + padding

        return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))

    def postprocess_tokens(self, tokens):
        tokens = tokens.tolist()  # np.array to list[int]
        try:  # Remove tokens at and after EOS if any.
            eos_pos = tokens.index(self.tokenizer.eos_id())
            tokens = tokens[:eos_pos]
        except ValueError:
            pass
        return self.tokenizer.decode(tokens)

    @staticmethod
    def split_and_keep_second_part(s):
        parts = s.split('\n', 1)
        if len(parts) > 1:
            return parts[1]
        return s

    def data_iterator(self, image_bytes, caption):
        image = Image.open(io.BytesIO(image_bytes))
        image = self.preprocess_image(image)
        tokens, mask_ar, _, mask_input = self.preprocess_tokens(caption, seqlen=SEQLEN)
        yield {
            "image": np.asarray(image),
            "text": np.asarray(tokens),
            "mask_ar": np.asarray(mask_ar),
            "mask_input": np.asarray(mask_input),
        }

    def make_predictions(self, data_iterator, *, num_examples=None,
                         batch_size=4, seqlen=SEQLEN, sampler="greedy"):
        outputs = []
        while True:
            # Construct a list of examples for the batch.
            examples = []
            try:
                for _ in range(batch_size):
                    examples.append(next(data_iterator))
                    examples[-1]["_mask"] = np.array(True)  # Indicates true example.
            except StopIteration:
                if len(examples) == 0:
                    return outputs

            # Not enough examples to complete a batch: pad by repeating the last one.
            while len(examples) % batch_size:
                examples.append(dict(examples[-1]))
                examples[-1]["_mask"] = np.array(False)  # Indicates padding example.

            batch = jax.tree.map(lambda *x: np.stack(x), *examples)
            batch = big_vision.utils.reshard(batch, self.data_sharding)

            tokens = self.decode({"params": self.params}, batch=batch,
                                 max_decode_len=seqlen, sampler=sampler)

            # Fetch model predictions to host and detokenize.
            tokens, mask = jax.device_get((tokens, batch["_mask"]))
            tokens = tokens[mask]  # Remove padding examples.
            responses = [self.postprocess_tokens(t) for t in tokens]

            for example, response in zip(examples, responses):
                outputs.append((example["image"], response))
                if num_examples and len(outputs) >= num_examples:
                    return outputs

    def process_result_to_bbox(self, image, caption, classes, w, h):
        image = ((image + 1) / 2 * 255).astype(np.uint8)  # [-1, 1] -> [0, 255]
        try:
            detections = sv.Detections.from_lmm(
                lmm='paligemma',
                result=caption,
                resolution_wh=(w, h),
                classes=classes)
            xyxy = list(detections.xyxy[0])
            x1, y1, x2, y2 = xyxy[0], xyxy[1], xyxy[2], xyxy[3]
            # Note: these values may still correspond to the 224x224 model input resolution.
            width = x2 - x1
            height = y2 - y1
            output = [x1, y1, width, height]  # Convert (x1, y1, x2, y2) to (x, y, w, h).
        except Exception as e:
            print('Error during detection')
            print(e)
            output = [0, 0, 0, 0]
        return output

    def predict(self, image: bytes, caption: str) -> Tuple[List[float], str]:
        image_original = Image.open(io.BytesIO(image))
        original_width, original_height = image_original.size
        if "detect" not in caption:
            caption = f"detect {caption}"
        # print("Making predictions...")
        for image, response in self.make_predictions(
                self.data_iterator(image, caption), num_examples=1):
            classes = response.replace("detect ", "")
            output = self.process_result_to_bbox(
                image, response, classes, original_width, original_height)
            return (output, response)


INFERENCE_IMAGE = '3_(backup)AdityaBY_img_14.png'
INFERENCE_PROMPT = "A mother takes a picture of her daughter holding a colourful wind spinner in front of the entrance."
TOKENIZER_PATH = '/home/lyka/air/Paligemma/pali-package/pali_open_vocab_annotations_tokenizer.model'
MODEL_PATH = '/home/lyka/air/Paligemma/pali-package/pali_open_vocab_annotations_segmentation.npz'

model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Load params - this can take up to 1 minute on T4 Colabs.
params = paligemma.load(None, MODEL_PATH, model_config)
paligemma_manager = PaliGemmaManager(model, params, tokenizer)

with open(INFERENCE_IMAGE, 'rb') as f:
    image_bytes = f.read()

output, response = paligemma_manager.predict(image_bytes, INFERENCE_PROMPT)

image = Image.open(INFERENCE_IMAGE)
detections = sv.Detections.from_lmm(
    lmm='paligemma',
    result=response,
    resolution_wh=image.size,
    classes=response)
coordinates = detections.xyxy[0]  # Assuming we want the first detection.
x1, y1, x2, y2 = coordinates
print('x1,y1,x2,y2:', coordinates)
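
# Optional: a minimal visualisation sketch. It draws the detection on the
# original-resolution image with supervision's BoxAnnotator and saves it to
# 'annotated.png' (the filename is an assumption, and the annotator API can
# vary slightly between supervision versions).
annotated = sv.BoxAnnotator().annotate(
    scene=np.array(image.convert("RGB")),
    detections=detections)
Image.fromarray(annotated).save('annotated.png')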
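
# For reference, a rough sketch of how PaliGemma's "<locYYYY>" detection output
# maps to pixel coordinates; this approximates what sv.Detections.from_lmm does,
# and exact parsing details may differ. Each box is four location tokens encoding
# y_min, x_min, y_max, x_max on a 0-1023 grid, followed by the class label.
import re

def parse_paligemma_locations(text, width, height):
    boxes = []
    pattern = r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s*([^;<]*)"
    for y_min, x_min, y_max, x_max, label in re.findall(pattern, text):
        boxes.append((
            int(x_min) / 1024 * width, int(y_min) / 1024 * height,
            int(x_max) / 1024 * width, int(y_max) / 1024 * height,
            label.strip()))
    return boxes

print(parse_paligemma_locations(response, *image.size))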