import os
import functools
import re

import gradio as gr
import PIL.Image
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import spaces

model_id = "google/paligemma2-3b-mix-448"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

HF_KEY = os.getenv("HF_KEY")
if not HF_KEY:
    raise ValueError("Please set the HF_KEY environment variable with your Hugging Face API token")

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, token=HF_KEY, trust_remote_code=True
).eval().to(device)
processor = PaliGemmaProcessor.from_pretrained(
    model_id, token=HF_KEY, trust_remote_code=True
)


@spaces.GPU
def infer(image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
    """Runs greedy generation and returns only the newly generated text."""
    inputs = processor(text=text, images=image, return_tensors="pt").to(device)
    input_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            # gr.Slider may deliver a float; generate() expects an int.
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
        )
    # Drop the prompt (image + text tokens) so only the model's answer is decoded.
    result = processor.decode(generated_ids[0][input_len:], skip_special_tokens=True)
    return result.lstrip("\n")


def parse_segmentation(input_image, input_text):
    """Runs a detect/segment prompt and converts the output for gr.AnnotatedImage."""
    out = infer(input_image, input_text, max_new_tokens=200)
    objs = extract_objs(out, input_image.size[0], input_image.size[1], unique_labels=True)
    # gr.AnnotatedImage takes (base_image, [(mask_or_box, label), ...]). Use the
    # decoded mask when <seg> tokens were emitted, otherwise the bounding box.
    annotations = [
        (obj['mask'] if obj.get('mask') is not None else obj['xyxy'], obj['name'] or '')
        for obj in objs
        if 'mask' in obj or 'xyxy' in obj
    ]
    return input_image, annotations


def _get_params(checkpoint):
    """Converts the PyTorch VQ-VAE checkpoint into Flax parameters for the mask decoder."""

    def transp(kernel):
        # (C0, C1, H, W) -> (H, W, C1, C0): Flax's Conv layout, and the layout that
        # ConvTranspose expects when transpose_kernel=True.
        return np.transpose(kernel, (2, 3, 1, 0))

    def conv(name):
        return {
            'bias': checkpoint[name + '.bias'],
            'kernel': transp(checkpoint[name + '.weight']),
        }

    def resblock(name):
        return {
            'Conv_0': conv(name + '.0'),
            'Conv_1': conv(name + '.2'),
            'Conv_2': conv(name + '.4'),
        }

    return {
        '_embeddings': checkpoint['_vq_vae._embedding'],
        'Conv_0': conv('decoder.0'),
        'ResBlock_0': resblock('decoder.2.net'),
        'ResBlock_1': resblock('decoder.3.net'),
        'ConvTranspose_0': conv('decoder.4'),
        'ConvTranspose_1': conv('decoder.6'),
        'ConvTranspose_2': conv('decoder.8'),
        'ConvTranspose_3': conv('decoder.10'),
        'Conv_1': conv('decoder.12'),
    }


def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
    """Looks up the 16 <seg> codebook indices and arranges them as a 4x4 latent grid."""
    batch_size, num_tokens = codebook_indices.shape
    assert num_tokens == 16, codebook_indices.shape
    unused_num_embeddings, embedding_dim = embeddings.shape

    encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
    encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
    return encodings


@functools.cache
def _get_reconstruct_masks():
    """Builds (once) a jitted function that decodes <seg> indices into 64x64 mask logits."""

    class ResBlock(nn.Module):
        features: int

        @nn.compact
        def __call__(self, x):
            original_x = x
            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
            x = nn.relu(x)
            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
            x = nn.relu(x)
            x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
            return x + original_x

    class Decoder(nn.Module):
        """Upsamples the 4x4 latent grid to a 64x64 single-channel mask."""

        @nn.compact
        def __call__(self, x):
            num_res_blocks = 2
            dim = 128
            num_upsample_layers = 4

            x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
            x = nn.relu(x)

            for _ in range(num_res_blocks):
                x = ResBlock(features=dim)(x)

            # Each stride-2 transposed conv doubles the resolution: 4 -> 8 -> 16 -> 32 -> 64.
            for _ in range(num_upsample_layers):
                x = nn.ConvTranspose(
                    features=dim,
                    kernel_size=(4, 4),
                    strides=(2, 2),
                    padding=2,
                    transpose_kernel=True,
                )(x)
                x = nn.relu(x)
                dim //= 2

            x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
            return x

    def reconstruct_masks(codebook_indices):
        quantized = _quantized_values_from_codebook_indices(
            codebook_indices, params['_embeddings']
        )
        return Decoder().apply({'params': params}, quantized)

    with open(_MODEL_PATH, 'rb') as f:
        params = _get_params(dict(np.load(f)))

    return jax.jit(reconstruct_masks, backend='cpu')


# One object in the model output: optional leading text, four <locNNNN> box
# tokens, an optional run of sixteen <segNNN> mask tokens, then an optional
# label terminated by "; " (the separator between objects).
_SEGMENT_DETECT_RE = re.compile(
    r'(.*?)' +
    r'<loc(\d{4})>' * 4 + r'\s*' +
    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
    r'\s*([^;<>]+)? ?(?:; )?',
)

# VQ-VAE checkpoint used to decode <seg> tokens into masks.
_MODEL_PATH = 'vae-oid.npz'


def extract_objs(text, width, height, unique_labels=False):
    """Splits model output into text chunks and objects with pixel boxes and optional masks."""
    objs = []
    seen = set()
    while text:
        m = _SEGMENT_DETECT_RE.match(text)
        if not m:
            break
        gs = list(m.groups())
        before = gs.pop(0)
        name = gs.pop()
        # <locNNNN> values are y1, x1, y2, x2 on a 1024-bin grid; scale them to pixels.
        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
        y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width))
        seg_indices = gs[4:20]
        if seg_indices[0] is None:
            mask = None
        else:
            seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
            m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
            # Rescale the decoder output with x * 0.5 + 0.5, clip to [0, 1],
            # then resize the 64x64 mask into the object's box.
            m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
            m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
            mask = np.zeros([height, width])
            if y2 > y1 and x2 > x1:
                mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0

        content = m.group()
        if before:
            objs.append(dict(content=before))
            content = content[len(before):]
        while unique_labels and name in seen:
            name = (name or '') + "'"
        seen.add(name)
        objs.append(dict(
            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
        text = text[len(before) + len(content):]

    if text:
        objs.append(dict(content=text))
    return objs


with gr.Blocks() as demo:
    with gr.Tab("Text Generation"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="pil", width=512, height=512)
                text_input = gr.Text(label="Input Text")
            with gr.Column():
                text_output = gr.Text(label="Text Output")
        chat_btn = gr.Button("Submit")
        tokens = gr.Slider(
            label="Max New Tokens",
            minimum=10,
            maximum=200,
            value=20,
            step=10,
        )
        chat_btn.click(
            fn=infer,
            inputs=[image, text_input, tokens],
            outputs=[text_output],
        )

    with gr.Tab("Segment/Detect"):
        with gr.Row():
            with gr.Column():
                seg_image = gr.Image(type="pil")
                seg_input = gr.Text(label="Entities to Segment/Detect")
                seg_btn = gr.Button("Submit")
            with gr.Column():
                annotated_image = gr.AnnotatedImage(label="Output")
        seg_btn.click(
            fn=parse_segmentation,
            inputs=[seg_image, seg_input],
            outputs=[annotated_image],
        )


if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)
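

# Sketch (illustrative only, not used by the app): a stripped-down version of
# the box decoding in extract_objs, for detection output without <seg> tokens.
# It assumes four <locNNNN> tokens (y1, x1, y2, x2 on a 1024-bin grid) followed
# by an optional label, with objects separated by ";". The names
# parse_detections and _DETECT_RE are hypothetical helpers for this sketch.
import re

# Four location tokens, then an optional label up to the next ";" or token.
_DETECT_RE = re.compile(r'<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s*([^;<>]*)')


def parse_detections(text, width, height):
    """Returns [(label, (x1, y1, x2, y2)), ...] in pixel coordinates."""
    results = []
    for m in _DETECT_RE.finditer(text):
        # Token order is y first, then x; each value is a fraction of 1024.
        y1, x1, y2, x2 = (int(v) / 1024 for v in m.groups()[:4])
        box = (round(x1 * width), round(y1 * height),
               round(x2 * width), round(y2 * height))
        results.append((m.group(5).strip(), box))
    return results


# Example: parse_detections("<loc0000><loc0000><loc0512><loc1023> cat", 200, 100)
# returns [('cat', (0, 0, 200, 50))].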
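

# Sketch (illustrative only): a NumPy restatement of
# _quantized_values_from_codebook_indices that makes the shapes explicit. Each
# mask is 16 <segNNN> codebook indices; every index selects one row of the
# VQ-VAE embedding table, and the 16 rows are laid out row-major as a 4x4
# channels-last grid, the NHWC layout Flax convolutions use by default. The
# function name and the toy 128 x 8 codebook in the example are hypothetical.
import numpy as np


def quantized_from_indices(codebook_indices, embeddings):
    """(batch, 16) int indices + (num_codes, dim) table -> (batch, 4, 4, dim)."""
    batch_size, num_tokens = codebook_indices.shape
    if num_tokens != 16:
        raise ValueError(f"expected 16 <seg> tokens per mask, got {num_tokens}")
    _, embedding_dim = embeddings.shape
    # Gather one embedding row per token, flattened across the batch.
    encodings = np.take(embeddings, codebook_indices.reshape(-1), axis=0)
    # Regroup the 16 rows of each mask into a 4x4 spatial grid.
    return encodings.reshape(batch_size, 4, 4, embedding_dim)


# Example with a hypothetical 128 x 8 codebook:
#   emb = np.arange(128 * 8, dtype=np.float32).reshape(128, 8)
#   idx = np.arange(32, dtype=np.int32).reshape(2, 16)
#   quantized_from_indices(idx, emb).shape  -> (2, 4, 4, 8)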
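

# Sketch (illustrative only): the final step of extract_objs in isolation. The
# decoder always produces a 64x64 mask for an object's box; it is resized to the
# box's pixel size and written into an image-sized canvas of zeros, which is the
# per-object mask parse_segmentation hands to gr.AnnotatedImage. The helper name
# paste_mask is hypothetical, and it assumes m64 is already scaled to [0, 1].
import numpy as np
import PIL.Image


def paste_mask(m64, xyxy, width, height):
    """Places a 64x64 mask (values in [0, 1]) into a (height, width) float canvas."""
    x1, y1, x2, y2 = xyxy
    mask = np.zeros((height, width), dtype=np.float32)
    if x2 > x1 and y2 > y1:  # Skip empty boxes, as extract_objs does.
        img = PIL.Image.fromarray((np.clip(m64, 0.0, 1.0) * 255).astype(np.uint8))
        # PIL sizes are (width, height), while NumPy slices are [rows, cols].
        resized = img.resize((x2 - x1, y2 - y1))
        mask[y1:y2, x1:x2] = np.asarray(resized, dtype=np.float32) / 255.0
    return mask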