Spaces:
Sleeping
Sleeping
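'''
整体思路:对每一个claim调用一次目标检测器,汇总全部object(对相近的物体框进行删除 考虑剔除目标框or其他办法)
1. 对每一个claim调用detector 得到bounding box list;phrase list
2. 按woodpecker的方式 调用blip2
3. 按之前的方式调用ocr模型
4. 汇总时需汇总bounding box(相近的需删除)

Overall approach: call the object detector once per claim and aggregate all detected
objects (near-duplicate boxes should be removed, either by dropping boxes or another method).
1. Call the detector for each claim to obtain a bounding-box list and a phrase list.
2. Call BLIP-2 the way Woodpecker does.
3. Call the OCR model as before.
4. When aggregating, merge the bounding boxes (near-duplicates must be removed).
'''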
import cv2 | |
import yaml | |
import torch | |
import os | |
import shortuuid | |
from PIL import Image | |
import numpy as np | |
from torchvision.ops import box_convert | |
from pipeline.tool.scene_text_model import * | |
# import sys | |
# sys.path.append("/home/wcx/wcx/EasyDetect/GroundingDINO") | |
from pipeline.GroundingDINO.groundingdino.util.inference import load_model, load_image, predict, annotate | |
# Score thresholds handed to the GroundingDINO detector.
BOX_TRESHOLD: float = 0.35   # box score threshold used in the detector api
TEXT_TRESHOLD: float = 0.25  # text score threshold used in the detector api

# Boxes whose normalised area (width * height, both in [0, 1]) falls below this
# value are treated as too small and filtered out.
AREA_THRESHOLD: float = 0.001

# Two boxes whose IoU is greater than this value are considered the same instance.
IOU_THRESHOLD: float = 0.95
class GroundingDINO:
    """GroundingDINO detector wrapper with optional MAERec scene-text recognition.

    The detector weights and the MAERec recogniser are created once in
    ``__init__`` and reused by every ``execute`` call.
    """

    def __init__(self, config=None, device="cuda:0"):
        """Load the GroundingDINO model and the MAERec text recogniser.

        Args:
            config: Optional configuration object. It is stored as
                ``self.config`` but not read by this class; accepting it keeps
                ``GroundingDINO(config=...)`` calls from raising ``TypeError``.
            device: Device string passed to ``predict`` (default ``"cuda:0"``,
                the value that used to be hard-coded).
        """
        self.config = config
        self.device = device
        self.BOX_TRESHOLD = BOX_TRESHOLD
        self.TEXT_TRESHOLD = TEXT_TRESHOLD
        # Box threshold used in scene-text mode: lowered from the default so
        # that more candidate regions pass the detector.
        self.TEXT_REC_BOX_TRESHOLD = 0.2
        self.text_rec = MAERec()
        # Load the detector only once; it is reused across execute() calls.
        self.model = load_model(
            "pipeline/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
            "models/groundingdino_swint_ogc.pth",
        )

    @staticmethod
    def _to_xyxy(boxes, h, w):
        """Convert normalised cxcywh boxes to pixel and normalised xyxy boxes.

        Returns:
            ``(xyxy, normed_xyxy)``: ``xyxy`` is a numpy array of pixel
            coordinates; ``normed_xyxy`` is a list of ``[x0, y0, x1, y1]``
            clipped to [0, 1] and rounded to 3 decimals.
        """
        scale = torch.Tensor([w, h, w, h])
        xyxy = box_convert(boxes=boxes * scale, in_fmt="cxcywh", out_fmt="xyxy").numpy()
        normed_xyxy = np.around(np.clip(xyxy / np.array([w, h, w, h]), 0., 1.), 3).tolist()
        return xyxy, normed_xyxy

    @staticmethod
    def _keep_indices(normed_xyxy, min_area):
        """Return indices of normalised xyxy boxes whose area is at least ``min_area``."""
        return [
            i
            for i, (x0, y0, x1, y1) in enumerate(normed_xyxy)
            if (x1 - x0) * (y1 - y0) >= min_area
        ]

    @staticmethod
    def _image_stem(image_path):
        """Return the file name of ``image_path`` without directory or extension."""
        return os.path.splitext(os.path.basename(image_path))[0]

    def execute(self, image_path, content, new_path, use_text_rec):
        """Detect the objects named in ``content`` and write an annotated image.

        Args:
            image_path: Path to the input image.
            content: Text prompt for the detector, e.g. ``"car.man.glasses.coat"``.
            new_path: Output directory. Default mode writes
                ``<new_path>/<image file name>``; scene-text mode writes the
                crops and the annotated image under ``<new_path>/<image stem>/``.
            use_text_rec: If true, run in scene-text mode: use the lower
                ``TEXT_REC_BOX_TRESHOLD``, discard boxes whose normalised area
                is below ``AREA_THRESHOLD``, and label every remaining box with
                the MAERec result for its crop instead of the detector phrase.

        Returns:
            dict with ``boxes`` (normalised ``[x0, y0, x1, y1]`` lists rounded
            to 3 decimals), ``logits``, ``phrases`` and ``new_path`` (path of
            the annotated image). Default mode additionally returns ``xyxy``
            (pixel boxes) and ``image_source``. In scene-text mode ``boxes``,
            ``logits`` and ``phrases`` are aligned index-for-index.
        """
        image_source, image = load_image(image_path)
        h, w, _ = image_source.shape

        if use_text_rec:
            boxes, logits, _ = predict(
                model=self.model,
                image=image,
                caption=content,
                box_threshold=self.TEXT_REC_BOX_TRESHOLD,
                text_threshold=self.TEXT_TRESHOLD,
                device=self.device,
            )
            xyxy, normed_xyxy = self._to_xyxy(boxes, h, w)

            # Filter out too-small boxes once and apply the same selection to
            # boxes, logits and normalised boxes, so every returned/annotated
            # box has exactly one recognised phrase at the same index.
            keep = self._keep_indices(normed_xyxy, AREA_THRESHOLD)
            keep_idx = torch.tensor(keep, dtype=torch.long)
            boxes, logits = boxes[keep_idx], logits[keep_idx]
            normed_xyxy = [normed_xyxy[i] for i in keep]

            cache_dir = os.path.join(new_path, self._image_stem(image_path))
            os.makedirs(cache_dir, exist_ok=True)
            pil_image = Image.fromarray(image_source)
            res_list = []
            for i in keep:
                crop_path = os.path.join(cache_dir, f"{shortuuid.uuid()}.jpg")
                pil_image.crop(xyxy[i]).save(crop_path)
                # The second item returned by MAERec.execute is used as the text.
                _, res = self.text_rec.execute(crop_path)
                print(res)
                res_list.append(res)

            annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=res_list)
            new_image_path = os.path.join(cache_dir, f"{shortuuid.uuid()}.jpg")
            cv2.imwrite(new_image_path, annotated_frame)
            return {"boxes": normed_xyxy, "logits": logits, "phrases": res_list, "new_path": new_image_path}

        new_image_path = os.path.join(new_path, os.path.basename(image_path))
        print(content)
        boxes, logits, phrases = predict(
            model=self.model,
            image=image,
            caption=content,
            box_threshold=self.BOX_TRESHOLD,
            text_threshold=self.TEXT_TRESHOLD,
            device=self.device,
        )
        annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
        if new_path:
            # Make sure the output directory exists before writing the image.
            os.makedirs(new_path, exist_ok=True)
        cv2.imwrite(new_image_path, annotated_frame)
        xyxy, normed_xyxy = self._to_xyxy(boxes, h, w)
        return {
            "boxes": normed_xyxy,
            "logits": logits,
            "phrases": phrases,
            "new_path": new_image_path,
            "xyxy": xyxy,
            "image_source": image_source,
        }
if __name__ == '__main__':
    # Manual smoke test on a local example image.
    # The detector is built with its defaults: passing ``config=`` to a
    # constructor that takes no such argument raised TypeError, and the YAML
    # config it loaded was not used for anything else.
    detector = GroundingDINO()
    # Other example images:
    #   /newdisk3/wcx/TextVQA/test_images/fca674d065b0ee2c.jpg
    #   /newdisk3/wcx/TextVQA/test_images/6648410adb1b08cb.jpg
    image_path = "/home/wcx/wcx/GroundingDINO/LVLM/cot/img_examples/image.jpg"
    out_dir = "/home/wcx/wcx/GroundingDINO/LVLM/cot/img_examples/extra/"
    # Scene-text mode example:
    # res = detector.execute(image_path=image_path, content="word.number", new_path=out_dir, use_text_rec=True)
    # print(res)
    res2 = detector.execute(image_path, content="car.man.glasses.coat", new_path=out_dir, use_text_rec=False)
    print(res2)
''' | |
dog cat | |
[[0.107, 0.005, 0.56, 0.999], [0.597, 0.066, 1.0, 0.953]] | |
'basketball', 'boy', 'car' | |
[0.741, 0.179, 0.848, 0.285], [0.773, 0.299, 0.98, 0.828], [0.001, 0.304, 0.992, 0.854] | |
'worlld | |
[0.405, 0.504, 0.726, 0.7] | |
''' | |
""" | |
cloud.agricultural exhibit.music.sky.food vendor.sign.street sign.carnival ride | |
/val2014/COCO_val2014_000000029056.jpg | |
""" | |