ZJF-Thunder
添加文件
e26e560
import cv2
import xml.etree.ElementTree as ET
import os,sys
from albumentations import HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, \
RandomRotate90, Transpose, ShiftScaleRotate, Blur, CenterCrop, RandomCrop, \
OpticalDistortion, GridDistortion, HueSaturationValue, \
IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, \
IAAPiecewiseAffine, IAASharpen, IAAEmboss, RandomContrast, \
RandomBrightness, Flip, OneOf, VerticalFlip, Resize, Rotate, Compose
import numpy as np
def pretty_xml(element, indent = '\t', newline = '\n', level=0):
if element: # 判断element是否有子元素
if (element.text is None) or element.text.isspace(): # 如果element的text没有内容
element.text = newline + indent * (level + 1)
else:
element.text = newline + indent * (level + 1) + element.text.strip() + newline + indent * (level + 1)
# else: # 此处两行如果把注释去掉,Element的text也会另起一行
# element.text = newline + indent * (level + 1) + element.text.strip() + newline + indent * level
temp = list(element) # 将element转成list
for subelement in temp:
if temp.index(subelement) < (len(temp) - 1): # 如果不是list的最后一个元素,说明下一个行是同级别元素的起始,缩进应一致
subelement.tail = newline + indent * (level + 1)
else: # 如果是list的最后一个元素, 说明下一行是母元素的结束,缩进应该少一个
subelement.tail = newline + indent * level
pretty_xml(subelement, indent, newline, level=level + 1) # 对子元素进行递归操作
def insert_object(root, xmin, xmax, ymin, ymax):
obj = ET.Element('object')
obj.tail = '\n'
root.append(obj)
name = ET.Element('name')
name.text = 'Bait'
name.tail = '\n'
obj.append(name)
pose = ET.Element('pose')
pose.text = 'Unspecified'
pose.tail = '\n'
obj.append(pose)
truncated = ET.Element('truncated')
truncated.text = '0'
truncated.tail = '\n'
obj.append(truncated)
difficult = ET.Element('difficult')
difficult.text = '0'
difficult.tail = '\n'
obj.append(difficult)
bndbox = ET.Element('bndbox')
bndbox.tail = '\n'
obj.append(bndbox)
x_min = ET.Element('xmin')
x_min.text = str(xmin)
x_min.tail = '\n'
bndbox.append(x_min)
y_min = ET.Element('ymin')
y_min.text = str(ymin)
y_min.tail = '\n'
bndbox.append(y_min)
x_max = ET.Element('xmax')
x_max.text = str(xmax)
x_max.tail = '\n'
bndbox.append(x_max)
y_max = ET.Element('ymax')
y_max.text = str(ymax)
y_max.tail = '\n'
bndbox.append(y_max)
BOX_COLOR = (255, 0, 0)
TEXT_COLOR = (255, 255, 255)
def visualize_bbox(img, bbox, class_id, class_idx_to_name, color=BOX_COLOR, thickness=2):
x_min, y_min, x_max, y_max = bbox
x_min, x_max, y_min, y_max = int(x_min), int(x_max), int(y_min), int(y_max)
cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness)
class_name = class_idx_to_name[class_id]
((text_width, text_height), _) = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)
cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), BOX_COLOR, -1)
cv2.putText(img, class_name, (x_min, y_min - int(0.3 * text_height)), cv2.FONT_HERSHEY_SIMPLEX, 0.35,TEXT_COLOR, lineType=cv2.LINE_AA)
return img
def visualize(annotations, category_id_to_name):
img = annotations['image'].copy()
for idx, bbox in enumerate(annotations['bboxes']):
img = visualize_bbox(img, bbox, annotations['category_id'][idx], category_id_to_name)
cv2.imshow('data_augmentation', img)
cv2.waitKey(0)
def get_aug(aug, min_area=0., min_visibility=0.):
return Compose(aug, bbox_params={'format': 'pascal_voc', 'min_area': min_area, 'min_visibility': min_visibility, 'label_fields': ['category_id']})
category_id_to_name = {0: 'Bait', 1: 'bait'}
aug_ver = get_aug([VerticalFlip(p = 1)]) #垂直方向翻转
aug_hor = get_aug([HorizontalFlip(p=1)]) #水平方向翻转
aug_res = get_aug([Resize(p=1, height=256, width=256)]) #resize
aug_cen = get_aug([CenterCrop(p=1, height=200, width=200)], min_area=4000)
aug_cen = get_aug([CenterCrop(p=1, height=100, width=100)], min_visibility=0.3) # 只返回变换后可见性大于 threshold 的 boxes
aug_ran = get_aug([RandomCrop(p=1, height=100, width=100)])
aug_SCR =get_aug([ShiftScaleRotate(shift_limit=0.0625,
scale_limit=1,
rotate_limit=45, p=1)]) #旋转、裁切
aug_rot = get_aug([Rotate(limit=60, p =1.0)])
aug_list = [aug_ver, aug_hor, aug_rot] #想用哪个,就添加在找个list里
#-------------------------- 读取xml,解析,增强图像,修改box信息,写入xml -----------------------------#
if __name__ == '__main__':
jpgPath = 'images'
xmlPath = 'xml/'
xmls = os.listdir(xmlPath)
for xml in xmls:
xmlName = xml.split('.')[0]
imgName = xmlName + '.jpg'
try:
tree = ET.parse(os.path.join(xmlPath, xml))
root = tree.getroot()
except Exception as e:
print('prase ' + xml + ' failed!')
sys.exit()
else:
image = cv2.imread(os.path.join(jpgPath, imgName))
for width in root.iter('width'):
if int(width.text) == 0:
width.text = str(image.shape[1])
for height in root.iter('height'):
if int(height.text) == 0:
height.text = str(image.shape[0])
tree.write(os.path.join(xmlPath, xmlName + '.xml'))
bboxes = []
for object in root.findall('object'):
for box in object.findall('bndbox'):
x_min = int(box.find('xmin').text)
x_max = int(box.find('xmax').text)
y_min = int(box.find('ymin').text)
y_max = int(box.find('ymax').text)
root.remove(object)
bboxes.append([x_min, y_min, x_max, y_max])
category_id = np.zeros(len(bboxes))
annotations = {'image': image, 'bboxes': bboxes, 'category_id': category_id}
for i, aug in enumerate(aug_list):
aug_type = str(aug).split('(')[1][4:]
augmented = aug(**annotations)
for iter in range(len(augmented['bboxes'])):
x_min, y_min, x_max, y_max = augmented['bboxes'][iter]
x_min, x_max, y_min, y_max = int(x_min), int(x_max), int(y_min), int(y_max)
insert_object(root, x_min, x_max, y_min, y_max)
for filename in root.iter('filename'):
name = filename.text.split('.')[0]
filename.text = name + aug_type + '.jpg'
for path in root.iter('path'):
pathname = path.text.split('.')[0]
path.text = pathname + aug_type + '.jpg'
for width in root.iter('width'):
width.text = str(image.shape[1])
for height in root.iter('height'):
height.text = str(image.shape[0])
if len(augmented['bboxes']) > 0:
cv2.imwrite(os.path.join(jpgPath, xmlName + aug_type +'.jpg'), augmented['image'])
pretty_xml(root)
tree.write(os.path.join(xmlPath, xmlName + aug_type +'.xml'))
for object in root.findall('object'):
root.remove(object)
# #centerCrop
# aug = get_aug([CenterCrop(p=1, height=100, width=100)])
# augmented = aug(**annotations)
# visualize(augmented, category_id_to_name)
# #certerCrop,并限定最小box面积
# aug = get_aug([CenterCrop(p=1, height=200, width=200)], min_area=4000)
# augmented = aug(**annotations)
# visualize(augmented, category_id_to_name)
# # 只返回变换后可见性大于 threshold 的 boxes
# aug = get_aug([CenterCrop(p=1, height=100, width=100)], min_visibility=0.3)
# augmented = aug(**annotations)
# visualize(augmented, category_id_to_name)
# aug = get_aug([RandomCrop(p=1, height=100, width=100)])
# augmented = aug(**annotations)
# visualize(augmented, category_id_to_name)
# #旋转、裁切
# aug =get_aug([ShiftScaleRotate(shift_limit=0.0625,
# scale_limit=1,
# rotate_limit=45, p=1)])
# augmented = aug(**annotations)
# visualize(augmented, category_id_to_name)
# #旋转
# aug = get_aug([Rotate(limit=60, p = 0.7)])
# augmented = aug(**annotations)
# visualize(augmented, category_id_to_name)
#多种增强混合,同时使用
# def augment_flips_color(p=.5):
# return Compose([
# # CLAHE(),
# Transpose(),
# ShiftScaleRotate(shift_limit=0.0625,
# scale_limit=1,
# rotate_limit=45, p=.75),
# # Blur(blur_limit=3),
# # OpticalDistortion(),
# # GridDistortion(),
# # HueSaturationValue()
# ], p=p)
#
# aug = augment_flips_color(p=1)
# augmented = aug(**annotations)
# visualize(augmented, category_id_to_name)