"""Panel demo: zero-shot image classification with a CLIP model.

The user supplies an image URL and a comma separated list of candidate
class names; the app downloads the image, scores it against every class
name with CLIP and renders the likelihoods as progress bars. The
resulting template is marked ``servable``.
"""

import random
from io import BytesIO
from typing import List, Tuple

import panel as pn
import requests
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Upper bound (seconds) on every outgoing HTTP request. Without it a
# stalled remote server would block the callback indefinitely.
_REQUEST_TIMEOUT = 10

# Hugging Face checkpoint used for both the processor and the model.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"


def set_random_url(_):
    """Put a random cat or dog image URL into the ``image_url`` widget.

    Registered as the click callback of the "Randomize URL" button; the
    event argument is ignored.

    Raises:
        requests.RequestException: if the image API cannot be reached in
            time or answers with an error status.
    """
    pet = random.choice(["cat", "dog"])
    api_url = f"https://api.the{pet}api.com/v1/images/search"
    with requests.get(api_url, timeout=_REQUEST_TIMEOUT) as resp:
        resp.raise_for_status()
        url = resp.json()[0]["url"]
    image_url.value = url


@pn.cache
def load_processor_model(
    processor_name: str, model_name: str
) -> Tuple[CLIPProcessor, CLIPModel]:
    """Load the CLIP processor and model, cached per pair of names.

    Args:
        processor_name: Checkpoint name passed to
            ``CLIPProcessor.from_pretrained``.
        model_name: Checkpoint name passed to ``CLIPModel.from_pretrained``.

    Returns:
        The ``(processor, model)`` pair.
    """
    processor = CLIPProcessor.from_pretrained(processor_name)
    model = CLIPModel.from_pretrained(model_name)
    return processor, model


@pn.cache
def open_image_url(image_url: str) -> Image.Image:
    """Download ``image_url`` and open it as a PIL image (cached per URL).

    Args:
        image_url: Address of the image to fetch.

    Returns:
        The opened PIL image.

    Raises:
        requests.RequestException: if the download fails, times out or
            the server answers with an error status.
    """
    with requests.get(image_url, timeout=_REQUEST_TIMEOUT) as resp:
        resp.raise_for_status()
        # Hand PIL an in-memory copy of the decoded body instead of the raw
        # socket stream: the buffer is seekable and stays valid after the
        # response has been closed.
        return Image.open(BytesIO(resp.content))


def get_similarity_scores(class_items: List[str], image: Image.Image) -> List[float]:
    """Score ``image`` against every entry of ``class_items`` with CLIP.

    Args:
        class_items: Candidate class names (text prompts).
        image: The image to classify.

    Returns:
        One likelihood per class name, in the order of ``class_items``.
        The values come from a softmax over the classes, so they sum to 1.
        The object returned is the first row of a NumPy array.
    """
    processor, model = load_processor_model(_CLIP_CHECKPOINT, _CLIP_CHECKPOINT)
    inputs = processor(
        text=class_items,
        images=[image],
        return_tensors="pt",  # pytorch tensors
    )
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    class_likelihoods = logits_per_image.softmax(dim=1).detach().numpy()
    # Only one image was submitted, so only the first row is meaningful.
    return class_likelihoods[0]


def _parse_class_names(raw: str) -> List[str]:
    """Split a comma separated string into stripped, non-empty names."""
    stripped = (name.strip() for name in raw.split(","))
    return [name for name in stripped if name]


def process_inputs(class_names: str, image_url: str):
    """Turn the user inputs into classification results as Panel objects.

    Args:
        class_names: Comma separated candidate class names, e.g.
            ``"cat, dog, parrot"``. Surrounding whitespace is removed and
            empty entries are ignored.
        image_url: URL of the image to classify.

    Returns:
        A ``pn.Column`` holding the image followed by one label and
        progress bar per class name, or a warning ``pn.pane.Alert`` when
        the URL or the list of class names is empty.
    """
    class_items = _parse_class_names(class_names)
    if not class_items:
        return pn.pane.Alert(
            "Enter at least one class name to classify the image.",
            alert_type="warning",
        )
    if not image_url.strip():
        return pn.pane.Alert(
            "Enter an image URL to classify.", alert_type="warning"
        )

    image = open_image_url(image_url)
    class_likelihoods = get_similarity_scores(class_items, image)

    # build the results column
    results_column = pn.Column("## 🎉 Here are the results!")
    results_column.append(
        pn.pane.Image(image, max_width=698, sizing_mode="scale_width")
    )

    for class_item, class_likelihood in zip(class_items, class_likelihoods):
        row_label = pn.widgets.StaticText(
            name=class_item, value=f"{class_likelihood:.2%}", margin=(0, 10)
        )
        row_bar = pn.indicators.Progress(
            max=100,
            value=int(class_likelihood * 100),
            sizing_mode="stretch_width",
            bar_color="secondary",
            margin=(0, 10),
        )
        results_column.append(pn.Column(row_label, row_bar))
    return results_column


# create widgets
randomize_url = pn.widgets.Button(name="Randomize URL", align="end")

image_url = pn.widgets.TextInput(
    name="Image URL to classify",
    value="https://cdn2.thecatapi.com/images/cct.jpg",
)
class_names = pn.widgets.TextInput(
    name="Comma separated class names",
    placeholder="Enter possible class names, e.g. cat, dog",
    value="cat, dog, parrot",
)

input_widgets = pn.Column(
    "## 😊 Click randomize or paste a URL to start classifying!",
    pn.Row(image_url, randomize_url),
    class_names,
)

# add interactivity
randomize_url.on_click(set_random_url)
interactive_result = pn.panel(
    pn.bind(process_inputs, image_url=image_url, class_names=class_names),
    loading_indicator=True,
)

# create dashboard
main = pn.WidgetBox(
    input_widgets,
    interactive_result,
)

pn.template.BootstrapTemplate(
    title="Panel Image Classification Demo",
    main=main,
    main_max_width="min(50%, 698px)",
    header_background="#F08080",
).servable(title="Panel Image Classification Demo")