|
import os
|
|
import json
|
|
import pandas as pd
|
|
|
|
|
|
class VisASolver(object):
    """Build a ``meta.json`` index for the VisA dataset from its split CSV.

    The CSV at ``{root}/split_csv/1cls.csv`` is loaded at construction time.
    :meth:`run` groups its rows by class name and phase and writes the result
    to ``{root}/meta.json``.
    """

    # Class names matched against the first CSV column; rows whose first
    # column holds any other value are ignored by ``run``.
    CLSNAMES = [
        'candle', 'capsules', 'cashew', 'chewinggum', 'fryum',
        'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3',
        'pcb4', 'pipe_fryum',
    ]

    def __init__(self, root='data/visa'):
        """Load the split CSV found under *root*.

        Args:
            root: Dataset root directory. ``split_csv/1cls.csv`` is read from
                it, and ``meta.json`` is later written into it by ``run``.
        """
        self.root = root
        self.meta_path = f'{root}/meta.json'
        # Only rows whose second CSV column equals one of these are exported.
        self.phases = ['train', 'test']
        self.csv_data = pd.read_csv(f'{root}/split_csv/1cls.csv', header=0)

    def run(self):
        """Write ``meta.json`` and print test-set normal/anomaly counts.

        The written JSON has the shape ``{phase: {cls_name: [record, ...]}}``
        for every phase in ``self.phases`` and every name in ``CLSNAMES``
        (an empty list when a class has no rows for a phase). Each record has
        the keys ``img_path``, ``mask_path`` (``''`` for non-anomalous rows),
        ``cls_name``, ``specie_name`` (always ``''``) and ``anomaly`` (1/0).
        """
        columns = self.csv_data.columns
        # NOTE(review): columns are addressed by position — 0: class name,
        # 1: phase, 2: label, 3: image path, 4: mask path. Presumably these
        # are the VisA split-CSV headers; confirm against the actual file.
        cls_col, phase_col = columns[0], columns[1]

        info = {phase: {} for phase in self.phases}
        anomaly_samples = 0
        normal_samples = 0

        for cls_name in self.CLSNAMES:
            cls_data = self.csv_data[self.csv_data[cls_col] == cls_name]
            for phase in self.phases:
                cls_data_phase = cls_data[cls_data[phase_col] == phase]
                cls_info = []
                # itertuples yields plain positional tuples, so row[k] is
                # unambiguous (integer keys on a label-indexed Series are
                # deprecated as positional lookups in pandas >= 2.1).
                for row in cls_data_phase.itertuples(index=False):
                    is_abnormal = row[2] == 'anomaly'
                    cls_info.append(dict(
                        img_path=row[3],
                        # The mask column is only read for anomalous rows.
                        mask_path=row[4] if is_abnormal else '',
                        cls_name=cls_name,
                        specie_name='',
                        anomaly=1 if is_abnormal else 0,
                    ))
                    # Only the test split contributes to the printed counts.
                    if phase == 'test':
                        if is_abnormal:
                            anomaly_samples += 1
                        else:
                            normal_samples += 1
                info[phase][cls_name] = cls_info

        with open(self.meta_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(info, indent=4) + "\n")

        print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build meta.json for the dataset root given on the
    # command line, falling back to the original hard-coded location.
    import argparse

    parser = argparse.ArgumentParser(
        description='Generate meta.json for the VisA dataset.')
    parser.add_argument(
        'root', nargs='?', default='/remote-home/iot_zhouqihang/data/Visa',
        help='VisA dataset root (must contain split_csv/1cls.csv).')
    args = parser.parse_args()

    runner = VisASolver(root=args.root)
    runner.run()
|
|
|