lixc committed on
Commit
d9e5c87
1 Parent(s): b96c6e9

extract features by SwAV-ResNet50w2 model

Browse files
Files changed (1) hide show
  1. featureExtractor.py +67 -0
featureExtractor.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as tvt
3
+ import pandas as pd
4
+ import os
5
+ from tqdm import tqdm
6
+ from PIL import Image
7
+
8
# Directory that receives one "<patient_barcode>.pt" feature file per patient.
outdir = 'pt_files/train'

# CSV file read with pandas below; the script uses its `orig`, `x`, `y`, `w`
# and `h` columns (image path plus a normalized bounding box).
yolo_crop_file = 'image_yolo.txt'

# Cap the number of CPU threads PyTorch uses for intra-op parallelism.
torch.set_num_threads(2)
12
+
13
def crop(img, x, y, w, h):
    """Crop ``img`` to a bounding box given in normalized center/size form.

    Args:
        img: Image-like object exposing ``size`` as ``(width, height)`` and a
            ``crop((left, upper, right, lower))`` method (e.g. a PIL image).
        x: Box center along the width, as a fraction of the image width.
        y: Box center along the height, as a fraction of the image height.
        w: Box width, as a fraction of the image width.
        h: Box height, as a fraction of the image height.

    Returns:
        The cropped image, or ``img`` itself when any coordinate is missing
        (``None`` or NaN), i.e. when there is no detection for this image.
    """
    import math

    # A row without a detection comes out of pandas as NaN; cropping with NaN
    # coordinates would raise, so fall back to the full image instead.
    if any(v is None or math.isnan(v) for v in (x, y, w, h)):
        return img

    W, H = img.size
    x1 = x * W - w * W / 2.0
    x2 = x * W + w * W / 2.0
    y1 = y * H - h * H / 2.0
    y2 = y * H + h * H / 2.0

    return img.crop((x1, y1, x2, y2))
25
+
26
def is_report_file(s):
    """Return True when the path string ``s`` contains the substring 'RPT'."""
    return 'RPT' in s


def get_barcode(s):
    """Return the third-from-last '/'-separated component of path ``s``.

    The script uses this component as the per-patient grouping key
    (presumably the patient barcode directory -- confirm against the layout
    of the paths in the input CSV).

    Raises:
        IndexError: if ``s`` has fewer than three '/'-separated components.
    """
    return s.split('/')[-3]
28
+
29
# --- Preprocessing ----------------------------------------------------------
CHANNEL = 3
IMAGE_SIZE = 448
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
normalize = tvt.Normalize(
    mean=torch.tensor(IMAGENET_DEFAULT_MEAN),
    std=torch.tensor(IMAGENET_DEFAULT_STD),
)
# Resize + CenterCrop produce a square IMAGE_SIZE x IMAGE_SIZE input, which is
# what the pre-allocated batch tensor in the loop below expects.
transform_ops = tvt.Compose([
    tvt.Resize(IMAGE_SIZE),
    tvt.CenterCrop(IMAGE_SIZE),
    tvt.ToTensor(),
    normalize,
])

model_path = './traced_swav_imagenet_layer2.pt'

# --- Input table ------------------------------------------------------------
df = pd.read_csv(yolo_crop_file)
df.insert(0, 'is_report_file', [is_report_file(s) for s in df.orig])
df.insert(0, 'patient_barcode', [get_barcode(s) for s in df.orig])
# Keep only non-report images; astype(bool) keeps the negation valid even when
# the table is empty and the column did not get a bool dtype.
df = df[~df.is_report_file.astype(bool)]

# --- Model ------------------------------------------------------------------
net = torch.jit.load(model_path)
net = net.cuda()
net.eval()

# torch.save does not create missing directories.
os.makedirs(outdir, exist_ok=True)

# --- Feature extraction: one output file per patient -------------------------
groups = df.groupby('patient_barcode')
for patient_barcode, dfg in tqdm(groups, total=groups.ngroups):
    outfile = f"{outdir}/{patient_barcode}.pt"
    # Resume support: patients that already have an output file are skipped.
    if os.path.exists(outfile):
        continue

    image_tensors = torch.zeros(len(dfg), CHANNEL, IMAGE_SIZE, IMAGE_SIZE)
    rows = zip(dfg.orig, dfg.x, dfg.y, dfg.w, dfg.h)
    for i, (image_file, x, y, w, h) in enumerate(rows):
        with open(image_file, 'rb') as f:
            # convert() decodes the pixel data while the file is still open.
            img = Image.open(f).convert('RGB')
        img = crop(img, x, y, w, h)
        image_tensors[i] = transform_ops(img)

    # All images of one patient go through the network as a single batch.
    with torch.no_grad():
        features = net(image_tensors.cuda()).cpu()

    # Write to a temporary name and rename afterwards, so an interrupted run
    # never leaves a truncated file that the exists() check would then skip.
    tmp_outfile = f"{outfile}.tmp"
    torch.save(features, tmp_outfile)
    os.replace(tmp_outfile, outfile)
67
+