Upload batched_inference.py
Browse files- batched_inference.py +104 -112
batched_inference.py
CHANGED
@@ -1,7 +1,4 @@
|
|
1 |
-
import csv
|
2 |
import torch.multiprocessing as multiprocessing
|
3 |
-
import pandas as pd
|
4 |
-
import numpy as np
|
5 |
import torchvision.transforms as transforms
|
6 |
from torch import autocast
|
7 |
from torch.utils.data import Dataset, DataLoader
|
@@ -9,79 +6,84 @@ from PIL import Image
|
|
9 |
import torch
|
10 |
from torchvision.transforms import InterpolationMode
|
11 |
from tqdm import tqdm
|
12 |
-
import random
|
13 |
import json
|
|
|
14 |
|
15 |
torch.backends.cuda.matmul.allow_tf32 = True
|
16 |
torch.backends.cudnn.allow_tf32 = True
|
17 |
-
|
18 |
torch.autograd.set_detect_anomaly(False)
|
19 |
-
|
20 |
torch.autograd.profiler.emit_nvtx(enabled=False)
|
21 |
torch.autograd.profiler.profile(enabled=False)
|
22 |
torch.backends.cudnn.benchmark = True
|
23 |
|
|
|
24 |
class ImageDataset(Dataset):
|
25 |
-
def __init__(self,
|
26 |
-
|
27 |
-
self.
|
28 |
-
self.
|
29 |
-
|
30 |
-
self.
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
def __len__(self):
|
66 |
-
return len(self.
|
67 |
|
68 |
def __getitem__(self, index):
|
69 |
-
image = Image.open(self.
|
70 |
-
ratio = image.height/image.width
|
71 |
if ratio > 2.0 or ratio < 0.5:
|
72 |
image = self.thin_transform(image)
|
73 |
else:
|
74 |
image = self.normal_transform(image)
|
75 |
|
76 |
-
|
77 |
return {
|
78 |
'image': image,
|
79 |
-
"image_name": self.all_image_names[index]
|
|
|
80 |
}
|
81 |
|
82 |
|
83 |
-
def prepare_model():
|
84 |
-
model = torch.load(
|
85 |
model.to(memory_format=torch.channels_last)
|
86 |
model = model.eval()
|
87 |
return model
|
@@ -94,20 +96,19 @@ def train(tagging_is_running, model, dataloader, train_data, output_queue):
|
|
94 |
|
95 |
with torch.no_grad():
|
96 |
for i, data in tqdm(enumerate(dataloader), total=int(len(train_data) / dataloader.batch_size)):
|
97 |
-
|
98 |
-
data, image_names = data['image'].to("cuda"), data["image_name"]
|
99 |
with autocast(device_type='cuda', dtype=torch.bfloat16):
|
100 |
-
outputs = model(
|
101 |
|
102 |
probabilities = torch.nn.functional.sigmoid(outputs)
|
103 |
-
output_queue.put((probabilities.to("cpu"),
|
104 |
|
105 |
counter += 1
|
106 |
_ = tagging_is_running.get()
|
107 |
print("Tagging finished!")
|
108 |
|
109 |
|
110 |
-
def tag_writer(tagging_is_running, output_queue,
|
111 |
with open("tags.json", "r") as file:
|
112 |
tags = json.load(file)
|
113 |
allowed_tags = sorted(tags)
|
@@ -116,78 +117,69 @@ def tag_writer(tagging_is_running, output_queue, output_file_name):
|
|
116 |
tag_count = len(allowed_tags)
|
117 |
assert tag_count == 7704, f"The length of loss scaling factor is not correct. Correct: 7704, current: {tag_count}"
|
118 |
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
if
|
131 |
-
tag
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
|
142 |
-
|
143 |
-
def set_seed(seed: int = 42) -> None:
|
144 |
-
np.random.seed(seed)
|
145 |
-
random.seed(seed)
|
146 |
-
torch.manual_seed(seed)
|
147 |
-
torch.cuda.manual_seed(seed)
|
148 |
-
# When running on the CuDNN backend, two further options must be set
|
149 |
-
torch.backends.cudnn.deterministic = True
|
150 |
-
torch.backends.cudnn.benchmark = False
|
151 |
-
# Set a fixed value for the hash seed
|
152 |
-
print(f"Random seed set as {seed}")
|
153 |
-
|
154 |
-
|
155 |
-
if __name__ == "__main__":
|
156 |
-
steps = 0
|
157 |
-
output_file_name = "your_file.csv"
|
158 |
-
set_seed()
|
159 |
multiprocessing.set_start_method('spawn')
|
160 |
output_queue = multiprocessing.Queue()
|
161 |
tagging_is_running = multiprocessing.Queue(maxsize=5)
|
162 |
tagging_is_running.put("Running!")
|
163 |
|
164 |
# initialize the computation device
|
165 |
-
if torch.cuda.is_available():
|
166 |
-
device = torch.device('cuda')
|
167 |
-
else:
|
168 |
raise RuntimeError("CUDA is not available!")
|
169 |
|
170 |
-
model = prepare_model().to("cuda")
|
171 |
-
batch_size = 128
|
172 |
-
|
173 |
|
174 |
# read the training csv file
|
175 |
-
train_csv = pd.read_csv('/path/to/a/list/of/files/and/their/extensions.csv')
|
176 |
# train dataset
|
177 |
-
|
178 |
-
train_csv, train=True
|
179 |
-
)
|
180 |
|
181 |
-
|
182 |
-
|
183 |
batch_size=batch_size,
|
184 |
shuffle=False,
|
185 |
-
num_workers=6,
|
186 |
-
pin_memory=True
|
|
|
187 |
)
|
188 |
-
process_writer = multiprocessing.Process(target=tag_writer,
|
|
|
189 |
process_writer.start()
|
190 |
-
process_tagger = multiprocessing.Process(target=train,
|
|
|
191 |
process_tagger.start()
|
192 |
process_writer.join()
|
193 |
process_tagger.join()
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch.multiprocessing as multiprocessing
|
|
|
|
|
2 |
import torchvision.transforms as transforms
|
3 |
from torch import autocast
|
4 |
from torch.utils.data import Dataset, DataLoader
|
|
|
6 |
import torch
|
7 |
from torchvision.transforms import InterpolationMode
|
8 |
from tqdm import tqdm
|
|
|
9 |
import json
|
10 |
+
import os
|
11 |
|
12 |
torch.backends.cuda.matmul.allow_tf32 = True
|
13 |
torch.backends.cudnn.allow_tf32 = True
|
|
|
14 |
torch.autograd.set_detect_anomaly(False)
|
|
|
15 |
torch.autograd.profiler.emit_nvtx(enabled=False)
|
16 |
torch.autograd.profiler.profile(enabled=False)
|
17 |
torch.backends.cudnn.benchmark = True
|
18 |
|
19 |
+
|
20 |
class ImageDataset(Dataset):
|
21 |
+
def __init__(self, image_folder_path, allowed_extensions):
|
22 |
+
self.allowed_extensions = allowed_extensions
|
23 |
+
self.all_image_paths, self.all_image_names, self.image_base_paths = self.get_image_paths(image_folder_path)
|
24 |
+
self.train_size = len(self.all_image_paths)
|
25 |
+
print(f"Number of images to be tagged: {self.train_size}")
|
26 |
+
self.thin_transform = transforms.Compose([
|
27 |
+
transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
|
28 |
+
transforms.CenterCrop(224),
|
29 |
+
transforms.ToTensor(),
|
30 |
+
transforms.Normalize(mean=[
|
31 |
+
0.48145466,
|
32 |
+
0.4578275,
|
33 |
+
0.40821073
|
34 |
+
], std=[
|
35 |
+
0.26862954,
|
36 |
+
0.26130258,
|
37 |
+
0.27577711
|
38 |
+
]) # Normalize image
|
39 |
+
])
|
40 |
+
self.normal_transform = transforms.Compose([
|
41 |
+
transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
|
42 |
+
transforms.ToTensor(),
|
43 |
+
transforms.Normalize(mean=[
|
44 |
+
0.48145466,
|
45 |
+
0.4578275,
|
46 |
+
0.40821073
|
47 |
+
], std=[
|
48 |
+
0.26862954,
|
49 |
+
0.26130258,
|
50 |
+
0.27577711
|
51 |
+
]) # Normalize image
|
52 |
+
|
53 |
+
])
|
54 |
+
|
55 |
+
def get_image_paths(self, folder_path):
|
56 |
+
image_paths = []
|
57 |
+
image_file_names = []
|
58 |
+
image_base_paths = []
|
59 |
+
for root, dirs, files in os.walk(folder_path):
|
60 |
+
for file in files:
|
61 |
+
if file.lower().split(".")[-1] in self.allowed_extensions:
|
62 |
+
image_paths.append((os.path.abspath(os.path.join(root, file))))
|
63 |
+
image_file_names.append(file.split(".")[0])
|
64 |
+
image_base_paths.append(root)
|
65 |
+
return image_paths, image_file_names, image_base_paths
|
66 |
|
67 |
def __len__(self):
|
68 |
+
return len(self.all_image_paths)
|
69 |
|
70 |
def __getitem__(self, index):
|
71 |
+
image = Image.open(self.all_image_paths[index]).convert("RGB")
|
72 |
+
ratio = image.height / image.width
|
73 |
if ratio > 2.0 or ratio < 0.5:
|
74 |
image = self.thin_transform(image)
|
75 |
else:
|
76 |
image = self.normal_transform(image)
|
77 |
|
|
|
78 |
return {
|
79 |
'image': image,
|
80 |
+
"image_name": self.all_image_names[index],
|
81 |
+
"image_root": self.image_base_paths[index]
|
82 |
}
|
83 |
|
84 |
|
85 |
+
def prepare_model(model_path: str):
|
86 |
+
model = torch.load(model_path)
|
87 |
model.to(memory_format=torch.channels_last)
|
88 |
model = model.eval()
|
89 |
return model
|
|
|
96 |
|
97 |
with torch.no_grad():
|
98 |
for i, data in tqdm(enumerate(dataloader), total=int(len(train_data) / dataloader.batch_size)):
|
99 |
+
this_data = data['image'].to("cuda")
|
|
|
100 |
with autocast(device_type='cuda', dtype=torch.bfloat16):
|
101 |
+
outputs = model(this_data)
|
102 |
|
103 |
probabilities = torch.nn.functional.sigmoid(outputs)
|
104 |
+
output_queue.put((probabilities.to("cpu"), data["image_name"], data["image_root"]))
|
105 |
|
106 |
counter += 1
|
107 |
_ = tagging_is_running.get()
|
108 |
print("Tagging finished!")
|
109 |
|
110 |
|
111 |
+
def tag_writer(tagging_is_running, output_queue, threshold):
|
112 |
with open("tags.json", "r") as file:
|
113 |
tags = json.load(file)
|
114 |
allowed_tags = sorted(tags)
|
|
|
117 |
tag_count = len(allowed_tags)
|
118 |
assert tag_count == 7704, f"The length of loss scaling factor is not correct. Correct: 7704, current: {tag_count}"
|
119 |
|
120 |
+
while not (tagging_is_running.qsize() > 0 and output_queue.qsize() > 0):
|
121 |
+
tag_probabilities, image_names, image_roots = output_queue.get()
|
122 |
+
tag_probabilities = tag_probabilities.tolist()
|
123 |
+
|
124 |
+
for per_image_tag_probabilities, image_name, image_root in zip(tag_probabilities, image_names, image_roots,
|
125 |
+
strict=True):
|
126 |
+
this_image_tags = []
|
127 |
+
this_image_tag_probabilities = []
|
128 |
+
for index, per_tag_probability in enumerate(per_image_tag_probabilities):
|
129 |
+
if per_tag_probability > threshold:
|
130 |
+
tag = allowed_tags[index]
|
131 |
+
if "placeholder" not in tag:
|
132 |
+
this_image_tags.append(tag)
|
133 |
+
this_image_tag_probabilities.append(str(int(round(per_tag_probability, 3) * 1000)))
|
134 |
+
output_file = os.path.join(image_root, os.path.splitext(image_name)[0] + ".txt")
|
135 |
+
with open(output_file, "w", encoding="utf-8") as this_output:
|
136 |
+
this_output.write(" ".join(this_image_tags))
|
137 |
+
this_output.write("\n")
|
138 |
+
this_output.write(" ".join(this_image_tag_probabilities))
|
139 |
+
|
140 |
+
|
141 |
+
def main():
|
142 |
+
image_folder_path = "/path/to/your/folder/"
|
143 |
+
# all images should be in this folder and/or its subfolders.
|
144 |
+
# I will generate a text file for every image.
|
145 |
+
model_path = "/path/to/your/model.pth"
|
146 |
+
allowed_extensions = {"jpg", "jpeg", "png", "webp"}
|
147 |
+
batch_size = 64
|
148 |
+
# if you have a 24GB card, you can try 256
|
149 |
+
threshold = 0.3
|
150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
multiprocessing.set_start_method('spawn')
|
152 |
output_queue = multiprocessing.Queue()
|
153 |
tagging_is_running = multiprocessing.Queue(maxsize=5)
|
154 |
tagging_is_running.put("Running!")
|
155 |
|
156 |
# initialize the computation device
|
157 |
+
if not torch.cuda.is_available():
|
|
|
|
|
158 |
raise RuntimeError("CUDA is not available!")
|
159 |
|
160 |
+
model = prepare_model(model_path).to("cuda")
|
|
|
|
|
161 |
|
162 |
# read the training csv file
|
|
|
163 |
# train dataset
|
164 |
+
dataset = ImageDataset(image_folder_path, allowed_extensions)
|
|
|
|
|
165 |
|
166 |
+
batched_loader = DataLoader(
|
167 |
+
dataset,
|
168 |
batch_size=batch_size,
|
169 |
shuffle=False,
|
170 |
+
num_workers=6, # if you have a big batch size, a good cpu, and enough cpu memory, try 12
|
171 |
+
pin_memory=True,
|
172 |
+
drop_last=False,
|
173 |
)
|
174 |
+
process_writer = multiprocessing.Process(target=tag_writer,
|
175 |
+
args=(tagging_is_running, output_queue, threshold))
|
176 |
process_writer.start()
|
177 |
+
process_tagger = multiprocessing.Process(target=train,
|
178 |
+
args=(tagging_is_running, model, batched_loader, dataset, output_queue,))
|
179 |
process_tagger.start()
|
180 |
process_writer.join()
|
181 |
process_tagger.join()
|
182 |
+
|
183 |
+
|
184 |
+
if __name__ == "__main__":
|
185 |
+
main()
|