'''Open-vocabulary detection (GroundingDINO) with optional scene-text recognition.

Overall plan: call the object detector once for every claim and collect all
detected objects (near-duplicate boxes should be removed -- either drop the
duplicate box or use some other approach).
1. Call the detector for each claim to get a bounding-box list and a phrase list.
2. Call BLIP-2 the same way Woodpecker does.
3. Call the OCR model as before.
4. When aggregating, merge the bounding boxes (near-duplicates must be removed).

This module provides the detector wrapper (``GroundingDINO``); in scene-text
mode it also runs a MAERec recognizer on each detected crop.
'''
import cv2
import yaml
import torch
import os
import shortuuid
from PIL import Image
import numpy as np
from torchvision.ops import box_convert
from pipeline.tool.scene_text_model import *
# import sys
# sys.path.append("/home/wcx/wcx/EasyDetect/GroundingDINO")
from pipeline.GroundingDINO.groundingdino.util.inference import load_model, load_image, predict, annotate

BOX_TRESHOLD = 0.35  # used in detector api.
TEXT_TRESHOLD = 0.25  # used in detector api.
AREA_THRESHOLD = 0.001  # used to filter out too small object (normalized area).
IOU_THRESHOLD = 0.95  # used to filter the same instance. greater than threshold means the same instance
SCENE_TEXT_BOX_TRESHOLD = 0.2  # lower box threshold used in scene-text mode.
DEFAULT_DEVICE = 'cuda:0'  # device used when the config does not specify one.


class GroundingDINO:
    """Wrapper around a GroundingDINO detector plus a MAERec text recognizer.

    Args:
        config: parsed configuration dict. Reads ``config["detector"]`` keys
            ``BOX_TRESHOLD``, ``TEXT_TRESHOLD``, ``config`` (model config path)
            and ``model`` (checkpoint path). Optional keys: ``device``
            (default ``'cuda:0'``) and ``SCENE_TEXT_BOX_TRESHOLD``
            (default ``0.2``); the defaults match the previously hard-coded values.
    """

    def __init__(self, config):
        self.config = config
        detector_cfg = self.config["detector"]
        self.BOX_TRESHOLD = detector_cfg["BOX_TRESHOLD"]
        self.TEXT_TRESHOLD = detector_cfg["TEXT_TRESHOLD"]
        self.SCENE_TEXT_BOX_TRESHOLD = detector_cfg.get("SCENE_TEXT_BOX_TRESHOLD", SCENE_TEXT_BOX_TRESHOLD)
        self.device = detector_cfg.get("device", DEFAULT_DEVICE)
        self.text_rec = MAERec()  # load only one time
        self.model = load_model(detector_cfg["config"], detector_cfg["model"], device=self.device)

    @staticmethod
    def _boxes_to_xyxy(boxes, h, w):
        """Convert detector boxes (cxcywh, scaled by image size) to xyxy.

        Returns:
            (xyxy, normed_xyxy): ``xyxy`` is a numpy array of pixel
            coordinates; ``normed_xyxy`` is a list of ``[x0, y0, x1, y1]``
            divided by the image size, clipped to [0, 1], rounded to 3 decimals.
        """
        scale = torch.Tensor([w, h, w, h])
        xyxy = box_convert(boxes=boxes * scale, in_fmt="cxcywh", out_fmt="xyxy").numpy()
        normed_xyxy = np.around(np.clip(xyxy / np.array([w, h, w, h]), 0., 1.), 3).tolist()
        return xyxy, normed_xyxy

    @staticmethod
    def _save_image(path, frame):
        """Write ``frame`` to ``path``, creating the parent directory if needed.

        Raises:
            OSError: if ``cv2.imwrite`` reports failure (it returns False
                instead of raising).
        """
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        if not cv2.imwrite(path, frame):
            raise OSError(f"failed to write annotated image to {path}")

    def execute(self, image_path, content, new_path, use_text_rec):
        """Detect ``content`` in an image and save an annotated copy.

        Args:
            image_path: path of the input image.
            content: caption passed to GroundingDINO ``predict``
                (e.g. ``"car.man.glasses.coat"``).
            new_path: output directory for saved images.
            use_text_rec: if True, run scene-text mode. Detect with a lower box
                threshold, drop boxes whose normalized area is below
                ``AREA_THRESHOLD``, and save each remaining crop under
                ``<new_path>/<image stem>/``. Each crop is then run through
                ``self.text_rec`` and the image annotated with those results.

        Returns:
            Scene-text mode: dict with ``"boxes"`` (normalized xyxy of kept
            boxes), ``"logits"`` (``predict`` logits of kept boxes),
            ``"phrases"`` (second value returned by ``self.text_rec.execute``
            for each kept crop, index-aligned with ``"boxes"``) and
            ``"new_path"`` (path of the annotated image).
            Object mode: dict with ``"boxes"``, ``"logits"``, ``"phrases"``
            (from ``predict``), ``"new_path"`` (annotated image path),
            ``"xyxy"`` (pixel-coordinate boxes) and ``"image_source"``.
        """
        image_source, image = load_image(image_path)
        h, w, _ = image_source.shape

        if use_text_rec:
            # Lower the box threshold for scene text.
            boxes, logits, phrases = predict(model=self.model, image=image, caption=content,
                                             box_threshold=self.SCENE_TEXT_BOX_TRESHOLD,
                                             text_threshold=self.TEXT_TRESHOLD, device=self.device)
            xyxy, normed_xyxy = self._boxes_to_xyxy(boxes, h, w)

            dir_name = os.path.splitext(os.path.basename(image_path))[0]
            cache_dir = os.path.join(new_path, dir_name)
            os.makedirs(cache_dir, exist_ok=True)

            pil_image = Image.fromarray(image_source)  # build once; crop() returns a new image
            keep = []
            res_list = []
            for idx, (box, norm_box) in enumerate(zip(xyxy, normed_xyxy)):
                # filter out too small object
                if (norm_box[2] - norm_box[0]) * (norm_box[3] - norm_box[1]) < AREA_THRESHOLD:
                    continue
                crop_path = os.path.join(cache_dir, f"{shortuuid.uuid()}.jpg")
                pil_image.crop(box).save(crop_path)
                _, res = self.text_rec.execute(crop_path)
                print(res)
                keep.append(idx)
                res_list.append(res)

            # Keep boxes/logits index-aligned with res_list: skipped (too small)
            # boxes must not be annotated or returned without a phrase.
            keep_idx = torch.as_tensor(keep, dtype=torch.long)
            boxes = boxes[keep_idx]
            logits = logits[keep_idx]
            normed_xyxy = [normed_xyxy[i] for i in keep]

            annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=res_list)
            new_image_path = os.path.join(cache_dir, f"{shortuuid.uuid()}.jpg")
            self._save_image(new_image_path, annotated_frame)
            return {"boxes": normed_xyxy, "logits": logits, "phrases": res_list, "new_path": new_image_path}

        save_path = os.path.join(new_path, os.path.basename(image_path))
        print(content)
        boxes, logits, phrases = predict(model=self.model, image=image, caption=content,
                                         box_threshold=self.BOX_TRESHOLD,
                                         text_threshold=self.TEXT_TRESHOLD, device=self.device)
        annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
        self._save_image(save_path, annotated_frame)
        xyxy, normed_xyxy = self._boxes_to_xyxy(boxes, h, w)
        return {"boxes": normed_xyxy, "logits": logits, "phrases": phrases, "new_path": save_path,
                "xyxy": xyxy, "image_source": image_source}


if __name__ == '__main__':
    # NOTE: yaml.FullLoader is acceptable only because this config is a trusted
    # local file; use yaml.safe_load for anything externally supplied.
    with open("/home/wcx/wcx/GroundingDINO/LVLM/config/config.yaml", "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    t = GroundingDINO(config=config)
    # /newdisk3/wcx/TextVQA/test_images/fca674d065b0ee2c.jpg
    # /newdisk3/wcx/TextVQA/test_images/6648410adb1b08cb.jpg
    image_path = "/home/wcx/wcx/GroundingDINO/LVLM/cot/img_examples/image.jpg"
    #input = {"text":{"question":"Describe the image","answer":""},"image":image_path}
    # res = t.execute(image_path=image_path,content="word.number",new_path="/home/wcx/wcx/GroundingDINO/LVLM/cot/img_examples/extra/",use_text_rec=True)
    # print(res)
    res2 = t.execute(image_path, content="car.man.glasses.coat",
                     new_path="/home/wcx/wcx/GroundingDINO/LVLM/cot/img_examples/extra/", use_text_rec=False)
    print(res2)

# Sample prompts and normalized boxes kept from earlier manual runs
# (informational only):
#   dog cat
#   [[0.107, 0.005, 0.56, 0.999], [0.597, 0.066, 1.0, 0.953]]
#   'basketball', 'boy', 'car'
#   [0.741, 0.179, 0.848, 0.285], [0.773, 0.299, 0.98, 0.828], [0.001, 0.304, 0.992, 0.854]
#   'worlld
#   [0.405, 0.504, 0.726, 0.7]
#
#   cloud.agricultural exhibit.music.sky.food vendor.sign.street sign.carnival ride
#   /val2014/COCO_val2014_000000029056.jpg