import streamlit as st
import yaml
import torch
import torch.nn.functional as F
from transformers import DetrImageProcessor, DetrForObjectDetection
from lib.IRRA.tokenizer import tokenize, SimpleTokenizer
from lib.IRRA.image import prepare_images
from lib.IRRA.model.build import build_model, IRRA
from PIL import Image
from pathlib import Path
from easydict import EasyDict


@st.cache_resource
def get_model() -> IRRA:
    """Build the IRRA retrieval model from its YAML config (cached across reruns)."""
    with open('model/configs.yaml') as f:
        args = EasyDict(yaml.load(f, Loader=yaml.FullLoader))
    args['training'] = False

    return build_model(args)


@st.cache_resource
def get_detr():
    """Load the DETR detector and its processor (cached across reruns)."""
    processor = DetrImageProcessor.from_pretrained(
        "facebook/detr-resnet-50", revision="no_timm")
    model = DetrForObjectDetection.from_pretrained(
        "facebook/detr-resnet-50", revision="no_timm")
    return model, processor


def segment_images(model, processor, images: list[str]) -> list[str]:
    """Detect people in each image and save the crops under segments/."""
    segments = []
    seg_id = 0  # avoid shadowing the built-in id()
    out_dir = Path('segments')
    out_dir.mkdir(exist_ok=True)

    for path in images:
        # Convert to RGB so crops can always be saved as JPEG.
        image = Image.open(path).convert('RGB')
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():  # inference only
            outputs = model(**inputs)

        # Post-processing expects (height, width); PIL's size is (width, height).
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=0.9)[0]

        for label, box in zip(results["labels"], results["boxes"]):
            box = [round(i, 2) for i in box.tolist()]  # (left, top, right, bottom)
            label = model.config.id2label[label.item()]
            # Keep only 'person' detections larger than 70 px on each side.
            if label == 'person' and box[2] - box[0] > 70 and box[3] - box[1] > 70:
                file = out_dir / f'img_{seg_id}.jpg'
                image.crop(box).save(file)
                segments.append(file.as_posix())
                seg_id += 1
    return segments


def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
    """Return a (1, N) tensor of text-image similarity scores."""
    tokenizer = SimpleTokenizer()
    txt = tokenize(text, tokenizer)
    imgs = prepare_images(images)

    with torch.no_grad():  # inference only
        image_feats = model.encode_image(imgs)
        text_feats = model.encode_text(txt.unsqueeze(0))

    # L2-normalise both embeddings so the dot product is cosine similarity.
    image_feats = F.normalize(image_feats, p=2, dim=1)
    text_feats = F.normalize(text_feats, p=2, dim=1)

    return text_feats @ image_feats.t()
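
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal Streamlit flow wiring the helpers above together: upload gallery
# images, crop person detections with DETR, then rank the crops against a
# text query with IRRA. The widget labels, the 'uploads' directory, and the
# main() entry point are assumptions for illustration; everything else uses
# only the functions defined in this file.
def main():
    st.title("Text-based person search")

    irra_model = get_model()
    detr_model, detr_processor = get_detr()

    uploaded = st.file_uploader("Gallery images", type=["jpg", "jpeg", "png"],
                                accept_multiple_files=True)
    query = st.text_input("Describe the person to find")

    if uploaded and query:
        # Persist uploads to disk so segment_images can open them by path.
        upload_dir = Path("uploads")
        upload_dir.mkdir(exist_ok=True)
        paths = []
        for f in uploaded:
            path = upload_dir / f.name
            path.write_bytes(f.getvalue())
            paths.append(path.as_posix())

        # Crop person detections, then rank the crops by similarity to the query.
        crops = segment_images(detr_model, detr_processor, paths)
        if crops:
            sims = get_similarities(query, crops, irra_model)  # shape (1, N)
            ranking = sims.squeeze(0).argsort(descending=True).tolist()
            for i in ranking:
                st.image(crops[i], caption=f"score: {sims[0, i].item():.3f}")
        else:
            st.warning("No person crops found above the detection threshold.")


if __name__ == "__main__":
    main()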