Spaces: Running on T4
"""PaliGemma demo gradio app.""" | |
import datetime | |
import functools | |
import glob | |
import json | |
import logging | |
import os | |
import time | |
import gradio as gr | |
import jax | |
import PIL.Image | |
import gradio_helpers | |
import models | |
import paligemma_parse | |
# Markdown rendered at the top of the demo page (links, model description and
# a usage disclaimer).
INTRO_TEXT = """🤲 PaliGemma demo\n\n
| [Paper](https://arxiv.org/abs/2407.07726)
| [GitHub](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md)
| [HF blog post](https://huggingface.co/blog/paligemma)
| [Google blog post](https://developers.googleblog.com/en/gemma-family-and-toolkit-expansion-io-2024)
| [Vertex AI Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/363)
| [Demo](https://huggingface.co/spaces/google/paligemma)
|\n\n
[PaliGemma](https://ai.google.dev/gemma/docs/paligemma) is an open vision-language model by Google,
inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and
built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343)
vision model and the [Gemma](https://arxiv.org/abs/2403.08295) language model. PaliGemma is designed as a versatile
model for transfer to a wide range of vision-language tasks such as image and short video captioning, visual question
answering, text reading, object detection and object segmentation.
\n\n
This space includes models fine-tuned on a mix of downstream tasks.
See the [blog post](https://huggingface.co/blog/paligemma) and
[README](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md)
for detailed information on how to use and fine-tune PaliGemma models.
\n\n
**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications.
"""
def make_image(value, visible):
  """Returns a `gr.Image` labelled 'Image' whose value is a file path."""
  return gr.Image(value, label='Image', type='filepath', visible=visible)


# Output-widget factories: the label is fixed, all other kwargs come from the
# call site.
make_annotated_image = functools.partial(gr.AnnotatedImage, label='Image')
make_highlighted_text = functools.partial(gr.HighlightedText, label='Output')

# Palette cycled over detected labels.
# https://coolors.co/4285f4-db4437-f4b400-0f9d58-e48ef1
COLORS = [
    '#4285f4',
    '#db4437',
    '#f4b400',
    '#0f9d58',
    '#e48ef1',
]
def compute(image, prompt, model_name, sampler):
  """Runs model inference and builds the output components.

  Args:
    image: The input image, either a PIL image or a path to an image file.
    prompt: Text prompt passed to the model.
    model_name: Model to run. It must be non-empty unless responses are mocked.
    sampler: Decoding strategy name (e.g. 'greedy'), forwarded to
      `models.generate`.

  Returns:
    A tuple `(highlighted_text, image, annotated_image)` of Gradio components.
    The plain image is shown only when the parsed output has no boxes or
    masks. Otherwise the annotated image is shown instead.

  Raises:
    gr.Error: If no image is given, or if no model is selected while not
      mocking.
  """
  if image is None:
    raise gr.Error('Image required')
  logging.info('prompt="%s"', prompt)
  if isinstance(image, str):
    image = PIL.Image.open(image)

  if gradio_helpers.should_mock():
    logging.warning('Mocking response')
    time.sleep(2.)
    output = paligemma_parse.EXAMPLE_STRING
  else:
    if not model_name:
      raise gr.Error('Models not loaded yet')
    output = models.generate(model_name, sampler, image, prompt)
  logging.info('output="%s"', output)

  width, height = image.size
  objs = paligemma_parse.extract_objs(output, width, height, unique_labels=True)

  # Assign colors in first-appearance order. Iterating a `set` of strings
  # would give a hash-dependent order (str hashing is randomized per process),
  # so the same output could be colored differently from run to run.
  labels = dict.fromkeys(obj.get('name') for obj in objs if obj.get('name'))
  color_map = {
      label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}

  highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]

  annotations = []
  for obj in objs:
    # Prefer the segmentation mask; fall back to the bounding box. Objects
    # with neither (plain text spans) are not drawn on the image.
    region = obj.get('mask')
    if region is None:
      region = obj.get('xyxy')
    if region is not None:
      annotations.append((region, obj.get('name') or ''))
  has_annotations = bool(annotations)

  return (
      make_highlighted_text(
          highlighted_text, visible=True, color_map=color_map),
      make_image(image, visible=not has_annotations),
      make_annotated_image(
          (image, annotations), visible=has_annotations, width=width,
          height=height, color_map=color_map),
  )
def warmup(model_name):
  """Runs one throwaway greedy inference for `model_name` on a blank image.

  The 1x1 RGB input and empty prompt are placeholders. The result is
  discarded.
  """
  blank = PIL.Image.new('RGB', (1, 1))
  compute(blank, '', model_name, 'greedy')
def reset():
  """Returns fresh values for (prompt, highlighted_text, image, annotated_image).

  These are an empty prompt, hidden highlighted text, a visible empty input
  image and a hidden annotated image.
  """
  cleared_prompt = ''
  hidden_output = make_highlighted_text('', visible=False)
  empty_input = make_image(None, visible=True)
  hidden_annotations = make_annotated_image(None, visible=False)
  return cleared_prompt, hidden_output, empty_input, hidden_annotations
def create_app():
  """Creates the demo UI.

  The app contains:
  - an image input and a prompt box;
  - model and decoding dropdowns;
  - Run/Clear buttons and the output widgets;
  - an examples gallery read from `examples/*.json`;
  - status/GPU lines that are filled in on page load.

  Returns:
    The constructed `gr.Blocks` app (not yet queued or launched).
  """

  def make_model(choices):
    # Preselect the first model. Use '' and hide the dropdown while no model
    # is available yet.
    return gr.Dropdown(
        value=(choices + [''])[0],
        choices=choices,
        label='Model',
        visible=bool(choices),
    )

  def make_prompt(value, visible=True):
    return gr.Textbox(value, label='Prompt', visible=visible)

  with gr.Blocks() as demo:

    ##### Main UI structure.

    gr.Markdown(INTRO_TEXT)
    with gr.Row():
      image = make_image(None, visible=True)  # input
      annotated_image = make_annotated_image(None, visible=False)  # output
      with gr.Column():
        with gr.Row():
          prompt = make_prompt('', visible=True)
        model_info = gr.Markdown(label='Model Info')
        with gr.Row():
          model = make_model([])
          samplers = [
              'greedy', 'nucleus(0.1)', 'nucleus(0.3)', 'temperature(0.5)']
          sampler = gr.Dropdown(
              value=samplers[0], choices=samplers, label='Decoding'
          )
        with gr.Row():
          run = gr.Button('Run', variant='primary')
          clear = gr.Button('Clear')
        highlighted_text = make_highlighted_text('', visible=False)

    ##### UI logic.

    def update_ui(model_name, prompt_text):
      # Re-show the prompt and describe the newly selected model.
      info = models.MODELS_INFO.get(model_name, 'No info.')
      return [
          make_prompt(prompt_text, visible=True),
          f'Model `{model_name}` – {info}',
      ]

    gr.on(
        [model.change],
        update_ui,
        [model, prompt],
        [prompt, model_info],
    )
    gr.on(
        [run.click, prompt.submit],
        compute,
        [image, prompt, model, sampler],
        [highlighted_text, image, annotated_image],
    )
    clear.click(
        reset, None, [prompt, highlighted_text, image, annotated_image]
    )

    ##### Examples.

    gr.set_static_paths(['examples/'])
    # `glob` returns paths in arbitrary, filesystem-dependent order, so sort
    # them to keep the gallery stable. Each file is closed right after parsing.
    all_examples = []
    for path in sorted(glob.glob('examples/*.json')):
      with open(path, encoding='utf-8') as f:
        all_examples.append(json.load(f))
    logging.info('loaded %d examples', len(all_examples))
    # Hidden proxies: gr.Examples writes into these, and the change handlers
    # below forward the values into the visible widgets.
    example_image = gr.Image(
        label='Image', visible=False)  # proxy, never visible
    example_model = gr.Text(
        label='Model', visible=False)  # proxy, never visible
    example_prompt = gr.Text(
        label='Prompt', visible=False)  # proxy, never visible
    example_license = gr.Markdown(
        label='Image License', visible=False)  # placeholder, never visible
    gr.Examples(
        examples=[
            [
                f'examples/{ex["name"]}.jpg',
                ex['prompt'],
                ex['model'],
                ex['license'],
            ]
            for ex in all_examples
            # Only offer examples whose model is listed in models.MODELS.
            if ex['model'] in models.MODELS
        ],
        inputs=[example_image, example_prompt, example_model, example_license],
    )

    ##### Examples UI logic.

    example_image.change(
        lambda image_path: (
            make_image(image_path, visible=True),
            make_annotated_image(None, visible=False),
            make_highlighted_text('', visible=False),
        ),
        example_image,
        [image, annotated_image, highlighted_text],
    )

    def example_model_changed(model_name):
      if model_name not in gradio_helpers.get_paths():
        raise gr.Error(f'Model "{model_name}" not loaded!')
      return model_name

    example_model.change(example_model_changed, example_model, model)
    example_prompt.change(make_prompt, example_prompt, prompt)

    ##### Status.

    status = gr.Markdown(f'Startup: {datetime.datetime.now()}')
    gpu_kind = gr.Markdown('GPU=?')
    # On page load, refresh the status line and repopulate the model dropdown
    # with the models listed by gradio_helpers.get_paths().
    demo.load(
        lambda: [
            gradio_helpers.get_status(),
            make_model(list(gradio_helpers.get_paths())),
        ],
        None,
        [status, model],
    )

    def get_gpu_kind():
      device = jax.devices()[0]
      # Outside mock mode a GPU is required. Fail loudly if JAX can't see one.
      if not gradio_helpers.should_mock() and device.platform != 'gpu':
        raise gr.Error('GPU not visible to JAX!')
      return f'GPU={device.device_kind}'

    demo.load(get_gpu_kind, None, gpu_kind)

  return demo
# Substrings (matched case-insensitively) marking an environment variable as
# likely to hold a credential, e.g. HF_TOKEN on Hugging Face Spaces. Values of
# such variables are never written to the logs.
_SENSITIVE_ENV_MARKERS = ('TOKEN', 'SECRET', 'PASSWORD', 'KEY', 'CREDENTIAL')


def _loggable_env_value(name, value):
  """Returns `value`, or '<redacted>' if `name` looks like a secret."""
  upper_name = name.upper()
  if any(marker in upper_name for marker in _SENSITIVE_ENV_MARKERS):
    return '<redacted>'
  return value


if __name__ == '__main__':
  logging.basicConfig(level=logging.INFO,
                      format='%(asctime)s - %(levelname)s - %(message)s')
  logging.info('JAX devices: %s', jax.devices())
  # Dump the environment for debugging, but keep credentials out of the logs.
  for k, v in os.environ.items():
    logging.info('environ["%s"] = %r', k, _loggable_env_value(k, v))
  gradio_helpers.set_warmup_function(warmup)
  for name, (repo, filename, revision) in models.MODELS.items():
    gradio_helpers.register_download(name, repo, filename, revision)
  create_app().queue().launch()