File size: 9,358 Bytes
e26e560 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 |
import cv2
import xml.etree.ElementTree as ET
import os,sys
from albumentations import HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, \
RandomRotate90, Transpose, ShiftScaleRotate, Blur, CenterCrop, RandomCrop, \
OpticalDistortion, GridDistortion, HueSaturationValue, \
IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, \
IAAPiecewiseAffine, IAASharpen, IAAEmboss, RandomContrast, \
RandomBrightness, Flip, OneOf, VerticalFlip, Resize, Rotate, Compose
import numpy as np
def pretty_xml(element, indent='\t', newline='\n', level=0):
    """Indent an ElementTree subtree in place so it serialises as readable XML.

    For every element that has children, its ``text`` is rewritten so that the
    first child starts on a new line indented ``level + 1`` times. Each child's
    ``tail`` is rewritten so the next sibling (or, for the last child, the
    parent's closing tag) starts on its own correctly indented line. Leaf
    elements keep their text unchanged.

    Args:
        element: Root of the subtree to format; modified in place.
        indent: String used for one level of indentation.
        newline: Line separator placed before every indented line.
        level: Nesting depth of ``element`` (0 for the document root).
    """
    children = list(element)
    # Check the child count explicitly: the truth value of an Element is
    # deprecated and is slated to become always-True.
    if len(children) > 0:
        child_indent = newline + indent * (level + 1)
        if element.text is None or element.text.isspace():
            element.text = child_indent
        else:
            # Mixed content: keep the stripped text on its own indented line.
            element.text = child_indent + element.text.strip() + child_indent
        last = len(children) - 1
        for position, child in enumerate(children):
            # Siblings share the child indentation. The last child's tail
            # precedes the parent's closing tag, so it is one level shallower.
            if position < last:
                child.tail = child_indent
            else:
                child.tail = newline + indent * level
            pretty_xml(child, indent, newline, level=level + 1)
def insert_object(root, xmin, xmax, ymin, ymax, name='Bait'):
    """Append a Pascal VOC ``<object>`` element with one bounding box to ``root``.

    Note that the argument order (xmin, xmax, ymin, ymax) differs from the
    order in which the coordinates are written inside ``<bndbox>``
    (xmin, ymin, xmax, ymax).

    Args:
        root: Annotation root element; modified in place.
        xmin, xmax, ymin, ymax: Box corners; each is written with ``str()``.
        name: Text of the ``<name>`` element. Defaults to ``'Bait'``, the
            value this function always wrote before the parameter existed.
    """
    obj = ET.SubElement(root, 'object')
    obj.tail = '\n'
    header = (
        ('name', name),
        ('pose', 'Unspecified'),
        ('truncated', '0'),
        ('difficult', '0'),
    )
    for tag, value in header:
        field = ET.SubElement(obj, tag)
        field.text = value
        field.tail = '\n'
    bndbox = ET.SubElement(obj, 'bndbox')
    bndbox.tail = '\n'
    for tag, value in (('xmin', xmin), ('ymin', ymin), ('xmax', xmax), ('ymax', ymax)):
        coord = ET.SubElement(bndbox, tag)
        coord.text = str(value)
        coord.tail = '\n'
BOX_COLOR = (255, 0, 0)  # default box colour; blue if the image is BGR as cv2.imread returns
TEXT_COLOR = (255, 255, 255)  # label text colour (white)


def visualize_bbox(img, bbox, class_id, class_idx_to_name, color=BOX_COLOR, thickness=2):
    """Draw one labelled bounding box onto ``img`` in place and return it.

    Args:
        img: Image array accepted by the cv2 drawing functions; modified in place.
        bbox: ``(x_min, y_min, x_max, y_max)`` in pixels; values are truncated
            to ``int``.
        class_id: Key looked up in ``class_idx_to_name`` to get the label text.
        class_idx_to_name: Mapping from class id to display name.
        color: Colour of the box outline and of the filled label background.
        thickness: Box outline thickness in pixels.

    Returns:
        The same ``img`` object.
    """
    x_min, y_min, x_max, y_max = (int(v) for v in bbox)
    cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness)
    class_name = class_idx_to_name[class_id]
    (text_width, text_height), _ = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)
    # Filled label background sitting just above the box's top-left corner.
    # It uses the caller's colour so it matches the box; this was previously
    # hard-coded to BOX_COLOR.
    cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), color, -1)
    cv2.putText(img, class_name, (x_min, y_min - int(0.3 * text_height)), cv2.FONT_HERSHEY_SIMPLEX, 0.35,
                TEXT_COLOR, lineType=cv2.LINE_AA)
    return img
def visualize(annotations, category_id_to_name):
    """Display the boxes of an augmentation result and block until a key is pressed.

    ``annotations`` is read for ``'image'``, ``'bboxes'`` and ``'category_id'``.
    The ``'category_id'`` key is only accessed when there is at least one box.
    Drawing happens on a copy, so the caller's image array is not modified.
    """
    canvas = annotations['image'].copy()
    for position, box in enumerate(annotations['bboxes']):
        label = annotations['category_id'][position]
        canvas = visualize_bbox(canvas, box, label, category_id_to_name)
    cv2.imshow('data_augmentation', canvas)
    cv2.waitKey(0)
def get_aug(aug, min_area=0., min_visibility=0.):
    """Wrap a list of albumentations transforms in a box-aware ``Compose``.

    Boxes are declared in ``pascal_voc`` format and their labels are carried
    in the ``category_id`` field. ``min_area`` and ``min_visibility`` are
    passed through unchanged as Compose's bounding-box filters.
    """
    box_settings = dict(
        format='pascal_voc',
        min_area=min_area,
        min_visibility=min_visibility,
        label_fields=['category_id'],
    )
    return Compose(aug, bbox_params=box_settings)
# Class id -> display name used by visualize(); ids 0 and 1 map to the same
# word in different case.
category_id_to_name = {0: 'Bait', 1: 'bait'}
aug_ver = get_aug([VerticalFlip(p=1)])  # always flip vertically
aug_hor = get_aug([HorizontalFlip(p=1)])  # always flip horizontally
aug_res = get_aug([Resize(p=1, height=256, width=256)])  # resize to 256x256
# Centre crop that drops boxes with an area below 4000 px^2. It has its own
# name; it used to be assigned to ``aug_cen`` and was immediately overwritten.
aug_cen_min_area = get_aug([CenterCrop(p=1, height=200, width=200)], min_area=4000)
# Centre crop that drops boxes whose visible fraction after cropping is below 0.3.
aug_cen = get_aug([CenterCrop(p=1, height=100, width=100)], min_visibility=0.3)
aug_ran = get_aug([RandomCrop(p=1, height=100, width=100)])  # random 100x100 crop
aug_SCR = get_aug([ShiftScaleRotate(shift_limit=0.0625,
                                    scale_limit=1,
                                    rotate_limit=45, p=1)])  # random shift / scale / rotate
aug_rot = get_aug([Rotate(limit=60, p=1.0)])  # random rotation within +/-60 degrees
# Transforms applied by the __main__ loop; add any of the above here to use it.
aug_list = [aug_ver, aug_hor, aug_rot]
# ----------- Read each XML, parse it, augment the image, update the boxes, write new XML -----------
if __name__ == '__main__':
    jpgPath = 'images'
    xmlPath = 'xml/'
    for xml_file in os.listdir(xmlPath):
        # splitext keeps dotted stems intact ('a.b.xml' -> 'a.b'), unlike split('.').
        xmlName, ext = os.path.splitext(xml_file)
        if ext.lower() != '.xml':
            continue  # ignore stray non-XML files instead of failing to parse them
        imgName = xmlName + '.jpg'
        try:
            tree = ET.parse(os.path.join(xmlPath, xml_file))
        except (ET.ParseError, OSError) as err:
            # sys.exit(str) reports on stderr and exits with a non-zero status.
            sys.exit('parse ' + xml_file + ' failed: ' + str(err))
        root = tree.getroot()

        image = cv2.imread(os.path.join(jpgPath, imgName))
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            print('cannot read image ' + os.path.join(jpgPath, imgName) + ', skipped')
            continue

        # Replace zero width/height entries with the real image size and save
        # that correction back to the source XML.
        for width in root.iter('width'):
            if int(width.text) == 0:
                width.text = str(image.shape[1])
        for height in root.iter('height'):
            if int(height.text) == 0:
                height.text = str(image.shape[0])
        tree.write(os.path.join(xmlPath, xml_file))

        # Collect every box and strip its <object>; the augmented boxes are
        # re-inserted separately for each transform below.
        bboxes = []
        for obj in root.findall('object'):
            for box in obj.findall('bndbox'):
                # int(float(...)) also accepts coordinates written as decimals.
                bboxes.append([int(float(box.find('xmin').text)),
                               int(float(box.find('ymin').text)),
                               int(float(box.find('xmax').text)),
                               int(float(box.find('ymax').text))])
            # Remove exactly once per object, even with zero or several <bndbox>.
            root.remove(obj)
        category_id = np.zeros(len(bboxes))
        annotations = {'image': image, 'bboxes': bboxes, 'category_id': category_id}

        # Remember the original names so every augmented copy is named from
        # them, not from the previous copy's already-suffixed name.
        filename_nodes = list(root.iter('filename'))
        path_nodes = list(root.iter('path'))
        base_filenames = [os.path.splitext(node.text or '')[0] for node in filename_nodes]
        base_paths = [os.path.splitext(node.text or '')[0] for node in path_nodes]

        for aug in aug_list:
            # Class name of the (first) wrapped transform, e.g. 'VerticalFlip',
            # used as the output file suffix.
            aug_type = type(aug.transforms[0]).__name__
            augmented = aug(**annotations)
            kept_boxes = augmented['bboxes']
            if len(kept_boxes) == 0:
                continue  # every box was filtered out: write neither image nor XML

            for bbox in kept_boxes:
                x_min, y_min, x_max, y_max = (int(v) for v in bbox[:4])
                insert_object(root, x_min, x_max, y_min, y_max)
            for node, base in zip(filename_nodes, base_filenames):
                node.text = base + aug_type + '.jpg'
            for node, base in zip(path_nodes, base_paths):
                node.text = base + aug_type + '.jpg'
            # Record the augmented image's size; crops and resizes change it.
            out_height, out_width = augmented['image'].shape[:2]
            for width in root.iter('width'):
                width.text = str(out_width)
            for height in root.iter('height'):
                height.text = str(out_height)

            cv2.imwrite(os.path.join(jpgPath, xmlName + aug_type + '.jpg'), augmented['image'])
            pretty_xml(root)
            tree.write(os.path.join(xmlPath, xmlName + aug_type + '.xml'))
            # Clear this transform's boxes before the next transform adds its own.
            for obj in root.findall('object'):
                root.remove(obj)
# #centerCrop
# aug = get_aug([CenterCrop(p=1, height=100, width=100)])
# augmented = aug(**annotations)
# visualize(augmented, category_id_to_name)
# #centerCrop, also enforcing a minimum box area
# aug = get_aug([CenterCrop(p=1, height=200, width=200)], min_area=4000)
# augmented = aug(**annotations)
# visualize(augmented, category_id_to_name)
# # Only keep boxes whose visibility after the transform is above the threshold
# aug = get_aug([CenterCrop(p=1, height=100, width=100)], min_visibility=0.3)
# augmented = aug(**annotations)
# visualize(augmented, category_id_to_name)
# aug = get_aug([RandomCrop(p=1, height=100, width=100)])
# augmented = aug(**annotations)
# visualize(augmented, category_id_to_name)
# #Shift, scale and rotate
# aug =get_aug([ShiftScaleRotate(shift_limit=0.0625,
# scale_limit=1,
# rotate_limit=45, p=1)])
# augmented = aug(**annotations)
# visualize(augmented, category_id_to_name)
# #Rotate
# aug = get_aug([Rotate(limit=60, p = 0.7)])
# augmented = aug(**annotations)
# visualize(augmented, category_id_to_name)
#Combine several augmentations and apply them together
# def augment_flips_color(p=.5):
# return Compose([
# # CLAHE(),
# Transpose(),
# ShiftScaleRotate(shift_limit=0.0625,
# scale_limit=1,
# rotate_limit=45, p=.75),
# # Blur(blur_limit=3),
# # OpticalDistortion(),
# # GridDistortion(),
# # HueSaturationValue()
# ], p=p)
#
# aug = augment_flips_color(p=1)
# augmented = aug(**annotations)
# visualize(augmented, category_id_to_name)
|