"""Gradio demo for PaliGemma 2: text generation plus detection/segmentation."""

import functools
import os
import re
import string

import flax.linen as nn
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
import PIL.Image
import spaces
import torch
import transformers
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

model_id = "gv-hf/paligemma2-10b-mix-448"
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']

# The model and processor are loaded once at import time and shared by all
# requests.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
processor = PaliGemmaProcessor.from_pretrained(model_id)


###### Transformers Inference
@spaces.GPU
def infer(
    image: PIL.Image.Image,
    text: str,
    max_new_tokens: int
) -> str:
    """Run greedy generation for one image/prompt pair.

    Args:
        image: Input image handed to the processor.
        text: Prompt text handed to the processor.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded output with the first ``len(text)`` characters (the echoed
        prompt) and any leading newlines removed.
    """
    inputs = processor(text=text, images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )
    result = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # The decoded sequence starts with the prompt; keep only the continuation.
    return result[0][len(text):].lstrip("\n")


##### Parse segmentation output tokens into masks
##### Also returns bounding boxes with their labels
def parse_segmentation(input_image, input_text):
    """Run the model on a detect/segment prompt and build annotations.

    Args:
        input_image: PIL image; its ``size`` supplies (width, height).
        input_text: Prompt such as "detect bird" or "segment cat".

    Returns:
        A ``(image, annotations)`` tuple where ``annotations`` is a list of
        ``(mask_or_box, label)`` pairs: the decoded mask when the model
        emitted segmentation tokens, otherwise the ``(x1, y1, x2, y2)`` box.
    """
    out = infer(input_image, input_text, max_new_tokens=200)
    width, height = input_image.size[0], input_image.size[1]
    objs = extract_objs(out.lstrip("\n"), width, height, unique_labels=True)

    annotations = [
        (
            obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
            obj['name'] or '',
        )
        for obj in objs
        # Plain-text fragments carry neither key and are not drawn.
        if 'mask' in obj or 'xyxy' in obj
    ]
    return (input_image, annotations)


######## Demo

INTRO_TEXT = """## PaliGemma 2 demo\n\n | [Github](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) | 
[Blogpost](https://huggingface.co/blog/paligemma2) | [Arxiv](https://arxiv.org/abs/2412.03555) |\n\n PaliGemma 2 is an open vision-language model by Google, inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343) vision model and the [Gemma 2](https://arxiv.org/abs/2408.00118) language model. PaliGemma 2 is designed as a versatile model for transfer to a wide range of vision-language tasks such as image and short video caption, visual question answering, text reading, object detection and object segmentation. \n\n This space includes models fine-tuned on a mix of downstream tasks, **inferred via 🤗 transformers**. See the [Blogpost](https://huggingface.co/blog/paligemma2) and [README](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) for detailed information how to use and fine-tune PaliGemma models. \n\n **This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications. 
"""


# NOTE(review): the nesting of the layout below was reconstructed from a copy
# of this file whose line breaks had been collapsed; confirm the Row/Column
# placement against the deployed UI.
with gr.Blocks(theme=gr.themes.Soft(),css="style.css") as demo:
    gr.Markdown(INTRO_TEXT)
    with gr.Tab("Text Generation"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="pil", width=512, height=512)
                text_input = gr.Text(label="Input Text")
            with gr.Column():
                text_output = gr.Text(label="Text Output")
                chat_btn = gr.Button()
                tokens = gr.Slider(
                    label="Max New Tokens",
                    info="Set to larger for longer generation.",
                    minimum=10,
                    maximum=200,
                    value=20,
                    step=10,
                )

        chat_inputs = [
            image,
            text_input,
            tokens
        ]
        chat_outputs = [
            text_output
        ]
        chat_btn.click(
            fn=infer,
            inputs=chat_inputs,
            outputs=chat_outputs,
        )

        examples = [["./examples/password.jpg", "what is the password", 10],
                    ["./examples/menu.JPG", "read text", 200],
                    ["./examples/mosque.jpg", "describe the image in great detail", 200],
                    ["./examples/infovqa.png", "what is the targeted emission reduction for France for 2023", 10],
                    ["./examples/chartqa.png", "for resolution-sensitive tasks, which variant is best", 50],
                    ["./examples/fiche.jpg", "When is this ticket dated and how much did it cost?", 20],
                    ["./examples/howto.jpg", "What does this image show?", 100],
                    ["./examples/billard1.jpg", "How many white/yellow balls are there?", 10],
                    ["./examples/bowie.jpg", "Who is this?", 10],
                    ["./examples/ulges.jpg", "Who is the author of this book?", 10]]
        gr.Markdown("Example images are licensed CC0 by [akolesnikoff@](https://github.com/akolesnikoff), [mbosnjak@](https://github.com/mbosnjak), [maximneumann@](https://github.com/maximneumann) and [merve](https://huggingface.co/merve).")

        gr.Examples(
            examples=examples,
            inputs=chat_inputs,
        )
    with gr.Tab("Segment/Detect"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="pil")
                seg_input = gr.Text(label="Entities to Segment/Detect")
                seg_btn = gr.Button("Submit")
            with gr.Column():
                annotated_image = gr.AnnotatedImage(label="Output")

        examples = [["./examples/venice.jpg", "detect bird"],
                    ["./examples/cats.png", "segment cat behind"],
                    ["./examples/bee.jpg", "detect red flower"],
                    ["./examples/barsik.jpg", "segment cat"],
                    ]
        gr.Markdown("Example images are licensed CC0 by [akolesnikoff@](https://github.com/akolesnikoff), [mbosnjak@](https://github.com/mbosnjak), [maximneumann@](https://github.com/maximneumann) and [merve](https://huggingface.co/merve).")
        gr.Examples(
            examples=examples,
            inputs=[image, seg_input],
        )
        seg_inputs = [
            image,
            seg_input
        ]
        seg_outputs = [
            annotated_image
        ]
        seg_btn.click(
            fn=parse_segmentation,
            inputs=seg_inputs,
            outputs=seg_outputs,
        )


### Postprocessing Utils for Segmentation Tokens
### Segmentation tokens are passed to another VAE which decodes them to a mask

_MODEL_PATH = 'vae-oid.npz'

# One detection/segmentation entry in the model output:
#   group 1      free text preceding the entry
#   groups 2-5   four `<locNNNN>` tokens, read by extract_objs as y1, x1, y2, x2
#   groups 6-21  optional run of sixteen `<segNNN>` codebook indices
#   group 22     optional label, followed by an optional "; " separator
# NOTE(review): the `<loc…>`/`<seg…>` literals were empty strings in the
# reviewed copy, which left only two capture groups and made extract_objs fail
# when unpacking the coordinates. They are restored here to the 4 + 16 groups
# that extract_objs reads.
_SEGMENT_DETECT_RE = re.compile(
    r'(.*?)' +
    r'<loc(\d{4})>' * 4 + r'\s*' +
    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
    r'\s*([^;<>]+)? ?(?:; )?',
)


def _get_params(checkpoint):
    """Converts PyTorch checkpoint to Flax params.

    Args:
        checkpoint: Mapping from PyTorch parameter names to arrays.

    Returns:
        Nested dict keyed by the Flax module names used by the decoder built
        in `_get_reconstruct_masks`, plus the codebook under '_embeddings'.
    """

    def transp(kernel):
        # Move the first two kernel axes to the end, in swapped order.
        return np.transpose(kernel, (2, 3, 1, 0))

    def conv(name):
        return {
            'bias': checkpoint[name + '.bias'],
            'kernel': transp(checkpoint[name + '.weight']),
        }

    def resblock(name):
        return {
            'Conv_0': conv(name + '.0'),
            'Conv_1': conv(name + '.2'),
            'Conv_2': conv(name + '.4'),
        }

    return {
        '_embeddings': checkpoint['_vq_vae._embedding'],
        'Conv_0': conv('decoder.0'),
        'ResBlock_0': resblock('decoder.2.net'),
        'ResBlock_1': resblock('decoder.3.net'),
        'ConvTranspose_0': conv('decoder.4'),
        'ConvTranspose_1': conv('decoder.6'),
        'ConvTranspose_2': conv('decoder.8'),
        'ConvTranspose_3': conv('decoder.10'),
        'Conv_1': conv('decoder.12'),
    }


def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
    """Looks up codebook vectors and arranges them on a 4x4 grid.

    Args:
        codebook_indices: Integer array shaped `[B, 16]`.
        embeddings: Codebook shaped `[num_embeddings, embedding_dim]`.

    Returns:
        Array shaped `[B, 4, 4, embedding_dim]`.
    """
    batch_size, num_tokens = codebook_indices.shape
    assert num_tokens == 16, codebook_indices.shape
    unused_num_embeddings, embedding_dim = embeddings.shape

    encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
    encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
    return encodings


@functools.cache
def _get_reconstruct_masks():
    """Reconstructs masks from codebook indices.

    The checkpoint at `_MODEL_PATH` is read on the first call only; the result
    is cached by `functools.cache`.

    Returns:
        A function that expects indices shaped `[B, 16]` of dtype int32, each
        ranging from 0 to 127 (inclusive), and that returns a decoded masks
        sized `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
    """

    class ResBlock(nn.Module):
        features: int

        @nn.compact
        def __call__(self, x):
            original_x = x
            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
            x = nn.relu(x)
            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
            x = nn.relu(x)
            x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
            return x + original_x

    class Decoder(nn.Module):
        """Upscales quantized vectors to mask."""

        @nn.compact
        def __call__(self, x):
            num_res_blocks = 2
            dim = 128
            num_upsample_layers = 4

            x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
            x = nn.relu(x)

            for _ in range(num_res_blocks):
                x = ResBlock(features=dim)(x)

            # Each stride-2 transposed conv is followed by halving the feature
            # count used for the next layer.
            for _ in range(num_upsample_layers):
                x = nn.ConvTranspose(
                    features=dim,
                    kernel_size=(4, 4),
                    strides=(2, 2),
                    padding=2,
                    transpose_kernel=True,
                )(x)
                x = nn.relu(x)
                dim //= 2

            x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)

            return x

    def reconstruct_masks(codebook_indices):
        # `params` is bound below, before this closure is ever called.
        quantized = _quantized_values_from_codebook_indices(
            codebook_indices, params['_embeddings']
        )
        return Decoder().apply({'params': params}, quantized)

    with open(_MODEL_PATH, 'rb') as f:
        params = _get_params(dict(np.load(f)))

    return jax.jit(reconstruct_masks, backend='cpu')


def extract_objs(text, width, height, unique_labels=False):
    """Returns objs for a string with "<loc>" and "<seg>" tokens.

    Args:
        text: Model output to parse.
        width: Image width in pixels, used to scale x coordinates.
        height: Image height in pixels, used to scale y coordinates.
        unique_labels: If true, repeated labels get "'" appended until unique.

    Returns:
        A list of dicts in order of appearance. Plain-text fragments are
        ``{'content': str}``. Detected entries additionally carry ``xyxy``
        (pixel box ``(x1, y1, x2, y2)``), ``mask`` (a ``[height, width]``
        float array in [0, 1], or ``None`` when no segmentation tokens were
        present) and ``name`` (label, possibly ``None``).
    """
    objs = []
    seen = set()
    while text:
        m = _SEGMENT_DETECT_RE.match(text)
        if not m:
            break
        gs = list(m.groups())
        before = gs.pop(0)
        name = gs.pop()
        # Location tokens are divided by 1024 and scaled to pixel coordinates.
        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
        y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))

        seg_indices = gs[4:20]
        if seg_indices[0] is None:
            # Detection only: the optional segmentation group did not match.
            mask = None
        else:
            seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
            m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
            # Map the decoder output from [-1, 1] to [0, 1].
            m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
            m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
            mask = np.zeros([height, width])
            # Paste the resized mask into its box; degenerate boxes stay empty.
            if y2 > y1 and x2 > x1:
                mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0

        content = m.group()
        if before:
            objs.append(dict(content=before))
            content = content[len(before):]
        while unique_labels and name in seen:
            name = (name or '') + "'"
        seen.add(name)
        objs.append(dict(
            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
        text = text[len(before) + len(content):]

    # Whatever is left did not match and is kept as trailing plain text.
    if text:
        objs.append(dict(content=text))

    return objs


#########

if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)