Spaces:
Runtime error
Runtime error
File size: 4,006 Bytes
f9827f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import argparse
import os
import pickle

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.models import inception_v3, Inception3
from tqdm import tqdm

from dataset import MultiResolutionDataset
from inception import InceptionV3
class Inception3Feature(Inception3):
    """torchvision ``Inception3`` whose forward pass returns pooled features.

    ``forward`` runs the convolutional stem and the ``Mixed_*`` blocks, applies
    an 8x8 average pool and returns a 2-D ``(batch, channels)`` tensor instead
    of class logits. ``self.transform_input`` and the auxiliary/classifier
    heads are never consulted.
    """

    def forward(self, x):
        """Return one flattened feature vector per image in ``x``.

        ``x`` is indexed as ``(N, C, H, W)``: dims 2 and 3 are treated as the
        spatial size. Any input that is not 299x299 is bilinearly resized to
        299x299 first. The shape notes below assume a 3-channel input, as the
        stem expects.
        """
        # De Morgan form of "height != 299 or width != 299"; same short-circuit order.
        if not (x.shape[2] == 299 and x.shape[3] == 299):
            x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=True)

        def max_pool(tensor):
            # 3x3 / stride-2 max pooling, used twice inside the stem.
            return F.max_pool2d(tensor, kernel_size=3, stride=2)

        # Applied in order; each trailing comment is that stage's OUTPUT
        # shape (C x H x W) for a 3 x 299 x 299 input.
        stages = (
            self.Conv2d_1a_3x3,  # 32 x 149 x 149
            self.Conv2d_2a_3x3,  # 32 x 147 x 147
            self.Conv2d_2b_3x3,  # 64 x 147 x 147
            max_pool,  # 64 x 73 x 73
            self.Conv2d_3b_1x1,  # 80 x 73 x 73
            self.Conv2d_4a_3x3,  # 192 x 71 x 71
            max_pool,  # 192 x 35 x 35
            self.Mixed_5b,  # 256 x 35 x 35
            self.Mixed_5c,  # 288 x 35 x 35
            self.Mixed_5d,  # 288 x 35 x 35
            self.Mixed_6a,  # 768 x 17 x 17
            self.Mixed_6b,  # 768 x 17 x 17
            self.Mixed_6c,  # 768 x 17 x 17
            self.Mixed_6d,  # 768 x 17 x 17
            self.Mixed_6e,  # 768 x 17 x 17
            self.Mixed_7a,  # 1280 x 8 x 8
            self.Mixed_7b,  # 2048 x 8 x 8
            self.Mixed_7c,  # 2048 x 8 x 8
        )
        for stage in stages:
            x = stage(x)

        # Global 8x8 average pool -> (N, 2048, 1, 1), then drop the two
        # singleton spatial dimensions.
        pooled = F.avg_pool2d(x, kernel_size=8)
        return pooled.view(pooled.shape[0], pooled.shape[1])
def load_patched_inception_v3():
    """Build the Inception-v3 network used to embed dataset images.

    Returns an ``InceptionV3`` from the local ``inception`` module. It is
    configured to emit only output block ``3`` and constructed with
    ``normalize_input=False``. Callers must therefore feed inputs that are
    already scaled to whatever range that wrapper expects (presumably
    [-1, 1]; confirm in inception.py).

    Historical note: an earlier version loaded torchvision's pretrained
    ``inception_v3`` and copied its weights into ``Inception3Feature``. That
    path is no longer used here.
    """
    # NOTE(review): block index 3 is presumably the final average-pool
    # (2048-d) output, as in pytorch-fid's InceptionV3; verify in inception.py.
    requested_blocks = [3]
    return InceptionV3(requested_blocks, normalize_input=False)
@torch.no_grad()
def extract_features(loader, inception, device):
    """Embed every batch yielded by ``loader`` and stack the results.

    Runs with autograd disabled and shows a tqdm progress bar over ``loader``.

    Args:
        loader: iterable of image-batch tensors, consumed once in order.
        inception: network whose output is indexable. Element ``[0]`` is taken
            and flattened to one row per sample.
        device: device each batch is moved to before the forward pass.

    Returns:
        A CPU tensor of shape ``(total_samples, feature_dim)``, with rows in
        loader order.
    """

    def embed(batch):
        # Forward one batch on ``device``; bring the flattened rows back to CPU.
        batch = batch.to(device)
        return inception(batch)[0].view(batch.shape[0], -1).to("cpu")

    return torch.cat([embed(batch) for batch in tqdm(loader)], 0)
def main():
    """Command-line entry point: embed a dataset and pickle its feature statistics.

    Embeds up to ``--n_sample`` images from the dataset at ``PATH`` with the
    patched Inception network. It then writes ``inception_<name>.pkl`` to the
    current directory, where ``<name>`` is ``PATH``'s basename without
    extension. The pickle holds a dict with the feature ``mean``, the feature
    covariance ``cov``, the image ``size`` and the dataset ``path``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    parser = argparse.ArgumentParser(
        description="Calculate Inception v3 features for datasets"
    )
    parser.add_argument(
        "--size",
        type=int,
        default=256,
        help="image sizes used for embedding calculation",
    )
    parser.add_argument(
        "--batch", default=64, type=int, help="batch size for inception networks"
    )
    parser.add_argument(
        "--n_sample",
        type=int,
        default=50000,
        help="number of samples used for embedding calculation",
    )
    parser.add_argument(
        "--flip", action="store_true", help="apply random flipping to real images"
    )
    parser.add_argument("path", metavar="PATH", help="path to dataset lmdb file")
    args = parser.parse_args()
    if args.n_sample < 1:
        parser.error("--n_sample must be a positive integer")

    inception = load_patched_inception_v3()
    # Inference mode; DataParallel spreads batches across visible GPUs.
    inception = nn.DataParallel(inception).eval().to(device)

    transform = transforms.Compose(
        [
            # With p=0 the transform stays in the pipeline but never flips.
            transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
            transforms.ToTensor(),
            # (x - 0.5) / 0.5 maps ToTensor's [0, 1] output to [-1, 1].
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )
    dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size)

    # Only the first n_sample images feed the statistics. Restrict the
    # (unshuffled) dataset up front instead of embedding everything and
    # slicing afterwards, which wasted compute and kept every feature vector
    # of a large dataset in memory. len(dset) is already required by
    # DataLoader's default sequential sampler.
    n_used = min(args.n_sample, len(dset))
    loader = DataLoader(
        Subset(dset, range(n_used)), batch_size=args.batch, num_workers=4
    )

    features = extract_features(loader, inception, device).numpy()
    print(f"extracted {features.shape[0]} features")

    mean = np.mean(features, 0)
    cov = np.cov(features, rowvar=False)

    name = os.path.splitext(os.path.basename(args.path))[0]
    with open(f"inception_{name}.pkl", "wb") as f:
        pickle.dump({"mean": mean, "cov": cov, "size": args.size, "path": args.path}, f)


if __name__ == "__main__":
    main()
|