lixc
committed on
Commit
•
d9e5c87
1
Parent(s):
b96c6e9
extract features by SwAV-ResNet50w2 model
Browse files- featureExtractor.py +67 -0
featureExtractor.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision.transforms as tvt
|
3 |
+
import pandas as pd
|
4 |
+
import os
|
5 |
+
from tqdm import tqdm
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
# Cap the number of intra-op CPU threads torch may use.
torch.set_num_threads(2)

# Directory that receives one "<patient_barcode>.pt" feature file per patient.
outdir = 'pt_files/train'
# CSV read with pandas; the script uses its `orig`, `x`, `y`, `w`, `h` columns.
yolo_crop_file = 'image_yolo.txt'
|
12 |
+
|
13 |
+
def crop(img, x, y, w, h):
    """Cut a centre/size box, given in relative coordinates, out of ``img``.

    Args:
        img: Image-like object exposing ``size`` as ``(width, height)`` and a
            ``crop((left, upper, right, lower))`` method.
        x: Horizontal box centre as a fraction of the image width.
        y: Vertical box centre as a fraction of the image height.
        w: Box width as a fraction of the image width.
        h: Box height as a fraction of the image height.

    Returns:
        Whatever ``img.crop`` returns for the computed pixel box.
    """
    W, H = img.size

    # Scale the relative centre/size to pixels, then step half the box
    # extent to either side of the centre to get the corner coordinates.
    x1 = x * W - w * W / 2.0
    x2 = x * W + w * W / 2.0
    y1 = y * H - h * H / 2.0
    y2 = y * H + h * H / 2.0

    return img.crop((x1, y1, x2, y2))
|
25 |
+
|
26 |
+
def is_report_file(s):
    """Return True when the path string ``s`` contains the marker 'RPT'."""
    return 'RPT' in s


def get_barcode(s):
    """Return the third-from-last '/'-separated component of path ``s``.

    Raises:
        IndexError: if ``s`` has fewer than three '/'-separated components.
    """
    return s.split('/')[-3]
|
28 |
+
|
29 |
+
# Shape of each preprocessed image tensor: CHANNEL x IMAGE_SIZE x IMAGE_SIZE.
CHANNEL = 3
IMAGE_SIZE = 448

# Per-channel statistics used to normalise the image tensors.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
normalize = tvt.Normalize(
    mean=torch.tensor(IMAGENET_DEFAULT_MEAN),
    std=torch.tensor(IMAGENET_DEFAULT_STD),
)

# Preprocessing pipeline: resize, centre-crop to IMAGE_SIZE, convert to a
# tensor, then normalise.
transform_ops = tvt.Compose([
    tvt.Resize(IMAGE_SIZE),
    tvt.CenterCrop(IMAGE_SIZE),
    tvt.ToTensor(),
    normalize,
])

# TorchScript model file loaded with torch.jit.load.
model_path = './traced_swav_imagenet_layer2.pt'
|
37 |
+
|
38 |
+
# Load the crop table; the `orig` column holds the image path of each row.
df = pd.read_csv(yolo_crop_file)
# Derive helper columns from the image path.
df.insert(0, 'is_report_file', [is_report_file(s) for s in df.orig])
df.insert(0, 'patient_barcode', [get_barcode(s) for s in df.orig])
# Keep only the rows that are not report files.
df = df[~df.is_report_file.astype(bool)]
|
42 |
+
|
43 |
+
# Load the TorchScript model, move it to the GPU and switch it to eval mode.
# NOTE(review): .cuda() requires a CUDA device; there is no CPU fallback.
net = torch.jit.load(model_path)
net = net.cuda()
net.eval()
|
46 |
+
|
47 |
+
|
48 |
+
# Extract features patient by patient and store them as "<outdir>/<barcode>.pt".

# torch.save does not create missing directories, so make sure it exists.
os.makedirs(outdir, exist_ok=True)

for patient_barcode, dfg in tqdm(df.groupby('patient_barcode'), total=df.patient_barcode.nunique()):
    outfile = f"{outdir}/{patient_barcode}.pt"
    if os.path.exists(outfile):
        # Already produced by an earlier run; skip so the script can resume.
        continue

    # Preprocess every image of this patient into one (N, C, H, W) batch.
    N = len(dfg)
    image_tensors = torch.zeros(N, CHANNEL, IMAGE_SIZE, IMAGE_SIZE)
    rows = zip(dfg.orig, dfg.x, dfg.y, dfg.w, dfg.h)
    for i, (image_file, x, y, w, h) in enumerate(rows):
        with open(image_file, 'rb') as f:
            img = Image.open(f)
            # convert() returns a new image, built while the file is still open.
            img = img.convert('RGB')
        img = crop(img, x, y, w, h)
        image_tensors[i] = transform_ops(img)

    # NOTE(review): the whole patient group goes through the network in a
    # single forward pass; very large groups may exhaust GPU memory.
    image_tensors = image_tensors.cuda()
    with torch.no_grad():
        features = net(image_tensors).cpu()

    # Write to a temporary name and rename afterwards, so an interrupted run
    # never leaves a truncated file that the exists-check above would skip.
    tmpfile = f"{outfile}.tmp"
    torch.save(features, tmpfile)
    os.replace(tmpfile, outfile)
|
67 |
+
|