# Author: Caoyunkang
# Commit: a25563f ("first commit", verified)
import os
import json
import pandas as pd
import random
from dataset import VISA_ROOT
class VisASolver(object):
    """Build a ``meta.json`` index for the VisA dataset from its split CSV.

    Reads ``{root}/split_csv/1cls.csv`` and writes ``{root}/meta.json``
    mapping ``phase -> class name -> list of sample records``.
    """

    CLSNAMES = [
        'candle', 'capsules', 'cashew', 'chewinggum', 'fryum',
        'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3',
        'pcb4', 'pipe_fryum',
    ]

    def __init__(self, root=VISA_ROOT, train_ratio=0.5):
        """Load the split CSV located under ``root``.

        Args:
            root: dataset root directory; must contain ``split_csv/1cls.csv``.
            train_ratio: stored as ``self.train_ratio``.
        """
        self.root = root
        self.meta_path = f'{root}/meta.json'
        self.phases = ['train', 'test']
        self.csv_data = pd.read_csv(f'{root}/split_csv/1cls.csv', header=0)
        # NOTE(review): not read anywhere in this class; kept for interface
        # compatibility with existing callers.
        self.train_ratio = train_ratio

    def run(self):
        """Generate and write the meta file."""
        self.generate_meta_info()

    @staticmethod
    def _make_record(label, img_path, mask_path, cls_name):
        """Build one sample entry; the mask path is kept only for anomalous samples."""
        is_abnormal = label == 'anomaly'
        return dict(
            img_path=img_path,
            mask_path=mask_path if is_abnormal else '',
            cls_name=cls_name,
            specie_name='',
            anomaly=int(is_abnormal),
        )

    def generate_meta_info(self):
        """Write ``self.meta_path`` as JSON indexed by phase, then class name.

        Output layout is ``{phase: {cls_name: [record, ...]}}`` for every phase
        in ``self.phases`` and every class in ``CLSNAMES``. Each record holds
        ``img_path``, ``mask_path`` (empty unless the label is ``'anomaly'``),
        ``cls_name``, ``specie_name`` (always empty) and ``anomaly`` (1 or 0).
        """
        # Columns are addressed by position; the CSV is assumed to be laid out
        # as [object, split, label, image, mask]. Resolving the names once and
        # indexing by label avoids the deprecated positional ``row[int]``
        # fallback on string-indexed Series (FutureWarning since pandas 2.1).
        obj_col, split_col, label_col, img_col, mask_col = self.csv_data.columns[:5]
        info = {phase: {} for phase in self.phases}
        for cls_name in self.CLSNAMES:
            cls_data = self.csv_data[self.csv_data[obj_col] == cls_name]
            for phase in self.phases:
                phase_data = cls_data[cls_data[split_col] == phase]
                info[phase][cls_name] = [
                    self._make_record(label, img_path, mask_path, cls_name)
                    for label, img_path, mask_path in zip(
                        phase_data[label_col],
                        phase_data[img_col],
                        phase_data[mask_col],
                    )
                ]
        with open(self.meta_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(info, indent=4) + "\n")
if __name__ == '__main__':
    # Script entry point: build meta.json for the configured VisA root.
    VisASolver(root=VISA_ROOT).run()