Add the cleaned dataset
- data/train_dataset.json +2 -2
- data/valid_dataset.json +2 -2
- run_medclip.sh +3 -3
- src/hybrid_clip/run_hybrid_clip.py +11 -5
- src/hybrid_clip/utils/roco_dataset.ipynb +0 -0
data/train_dataset.json CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:898792e7bb3f1b4b390d35a6e2bad326a1fd5db44e169dd33b02bce5f1d6a4dc
+size 14451628
data/valid_dataset.json CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:5de7c2611565819c960f93f007a0015355686fc0c76ab496729474981c441a3d
+size 1834366
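Both dataset files are Git LFS pointers, so this commit swaps only the `oid`/`size` lines shown above; the cleaned payloads themselves (roughly 14 MB for train, 1.8 MB for valid) live in LFS storage. As a rough illustration of the pointer format rather than anything from this repository, here is a minimal Python sketch that parses a pointer and checks a downloaded blob against the recorded SHA-256 and size (the local path is only an example and assumes `git lfs pull` has already materialized the real file):

```python
import hashlib
from pathlib import Path


def parse_lfs_pointer(text: str) -> dict:
    """Split a Git LFS pointer ('key value' per line) into a dict."""
    return dict(line.split(" ", 1) for line in text.splitlines() if line.strip())


def verify_blob(pointer_text: str, blob_path: str) -> bool:
    """Check a downloaded file against the oid/size recorded in the pointer."""
    fields = parse_lfs_pointer(pointer_text)
    expected_oid = fields["oid"].split(":", 1)[1]  # drop the "sha256:" prefix
    expected_size = int(fields["size"])
    data = Path(blob_path).read_bytes()
    return len(data) == expected_size and hashlib.sha256(data).hexdigest() == expected_oid


# The train pointer from this commit.
pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:898792e7bb3f1b4b390d35a6e2bad326a1fd5db44e169dd33b02bce5f1d6a4dc
size 14451628"""
# print(verify_blob(pointer, "data/train_dataset.json"))
```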
run_medclip.sh CHANGED
@@ -1,15 +1,15 @@
 python src/hybrid_clip/run_hybrid_clip.py \
-    --output_dir ./snapshots/
+    --output_dir ./snapshots/vision_augmented \
     --text_model_name_or_path="roberta-base" \
     --vision_model_name_or_path="openai/clip-vit-base-patch32" \
     --tokenizer_name="roberta-base" \
     --train_file="data/train_dataset.json" \
     --validation_file="data/valid_dataset.json" \
     --do_train --do_eval \
-    --num_train_epochs="40" --max_seq_length
+    --num_train_epochs="40" --max_seq_length 128 \
     --per_device_train_batch_size="64" \
     --per_device_eval_batch_size="64" \
-    --learning_rate="5e-
+    --learning_rate="5e-5" --warmup_steps="0" --weight_decay 0.1 \
     --overwrite_output_dir \
     --preprocessing_num_workers 32 \
     # --push_to_hub
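The new flags pin down the recipe: output under `./snapshots/vision_augmented`, 40 epochs, `max_seq_length` 128, batch size 64 per device, learning rate 5e-5 with zero warmup steps, and weight decay 0.1. The optimizer construction itself is not part of this diff, so the snippet below is only a hedged sketch of how flags like these are commonly wired into an optax AdamW setup in a JAX/Flax training loop; `num_train_steps` is a placeholder, not a value taken from the script:

```python
import optax

# Values mirror the flags in run_medclip.sh; everything else is illustrative.
learning_rate = 5e-5
warmup_steps = 0
weight_decay = 0.1
num_train_steps = 40 * 1_000  # placeholder for num_train_epochs * steps_per_epoch

# Linear warmup into a linear decay to zero, a common schedule for CLIP-style runs.
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate,
                                  transition_steps=max(warmup_steps, 1))
decay_fn = optax.linear_schedule(init_value=learning_rate, end_value=0.0,
                                 transition_steps=num_train_steps - warmup_steps)
lr_schedule = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])

optimizer = optax.adamw(learning_rate=lr_schedule, weight_decay=weight_decay)
```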
src/hybrid_clip/run_hybrid_clip.py CHANGED
@@ -37,6 +37,7 @@ from torchvision.datasets import VisionDataset
 from torchvision.io import ImageReadMode, read_image
 from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
 from torchvision.transforms.functional import InterpolationMode
+from torchvision.transforms.transforms import GaussianBlur, RandomAutocontrast, RandomHorizontalFlip
 from tqdm import tqdm
 
 import jax
@@ -178,6 +179,9 @@ class Transform(torch.nn.Module):
         self.transforms = torch.nn.Sequential(
             Resize([image_size], interpolation=InterpolationMode.BICUBIC),
             CenterCrop(image_size),
+            GaussianBlur(3, sigma=(0.05, 0.2)),
+            RandomAutocontrast(p=0.5),
+            RandomHorizontalFlip(p=0.5),
             ConvertImageDtype(torch.float),
             Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
         )
@@ -224,10 +228,10 @@ class ImageTextDataset(VisionDataset):
         self.image_paths = []
 
         for example in examples:
-            self.captions.extend(
-            self.image_paths.
-
-
+            # self.captions.extend(example["captions"])
+            # self.image_paths.append(example["image_path"])
+            self.captions.extend(example["captions"][:captions_per_image])
+            self.image_paths.extend([example["image_path"]] * captions_per_image)
 
     def _load_image(self, idx: int):
         path = self.image_paths[idx]
@@ -374,8 +378,10 @@ def main():
     # Use collate function to tokenizer the text and convert the processed images to numpy
     def collate_fn(examples):
         pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
+        # pixel_values = torch.stack([example[0] for example in examples]).numpy()
         captions = [example[1] for example in examples]
-        inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", return_tensors="np")
+        inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", return_tensors="np",
+                           truncation=True)
 
         batch = {
             "pixel_values": pixel_values,
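The `Transform` hunk is what earns the new `vision_augmented` output directory: a mild Gaussian blur, random autocontrast, and random horizontal flip are inserted between the crop and the float conversion, so the stochastic augmentations run on the uint8 image before CLIP's normalization. Below is a self-contained sketch of the resulting pipeline; the `image_size` value and the dummy tensor (standing in for `read_image` output) are illustrative:

```python
import torch
from torchvision.transforms import (CenterCrop, ConvertImageDtype, GaussianBlur, Normalize,
                                    RandomAutocontrast, RandomHorizontalFlip, Resize)
from torchvision.transforms.functional import InterpolationMode

image_size = 224  # illustrative value; CLIP ViT-B/32 expects 224x224 inputs

# Same ordering as the hunk above: resize/crop, stochastic augmentation, then
# float conversion and CLIP's normalization statistics.
transforms = torch.nn.Sequential(
    Resize([image_size], interpolation=InterpolationMode.BICUBIC),
    CenterCrop(image_size),
    GaussianBlur(3, sigma=(0.05, 0.2)),
    RandomAutocontrast(p=0.5),
    RandomHorizontalFlip(p=0.5),
    ConvertImageDtype(torch.float),
    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
)

dummy = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)  # stand-in for read_image(path)
with torch.no_grad():
    out = transforms(dummy)
print(out.shape, out.dtype)  # torch.Size([3, 224, 224]) torch.float32
```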
src/hybrid_clip/utils/roco_dataset.ipynb CHANGED
The diff for this file is too large to render. See the raw diff.
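Setting the notebook aside, the other behavioral changes in `run_hybrid_clip.py` are the caption handling in `ImageTextDataset` and the `truncation=True` flag in `collate_fn`: the dataset now keeps at most `captions_per_image` captions per record and repeats the image path that many times to keep caption and image indices aligned, while truncation keeps over-length captions from breaking the fixed-size numpy batch. A rough sketch of that pairing with made-up records (the `image_path`/`captions` field names come from the hunk above; the values are invented, and `roberta-base` is the tokenizer named in run_medclip.sh):

```python
from transformers import AutoTokenizer

# Hypothetical records shaped like the entries the dataset class reads.
examples = [
    {"image_path": "images/ROCO_00001.jpg", "captions": ["Chest X-ray, PA view.", "Frontal radiograph."]},
    {"image_path": "images/ROCO_00002.jpg", "captions": ["Axial CT of the abdomen.", "CT, axial slice."]},
]
captions_per_image = 2

captions, image_paths = [], []
for example in examples:
    # Mirrors the new dataset logic: keep at most captions_per_image captions and
    # repeat the path that many times (this assumes each record has at least that
    # many captions; otherwise the two lists would drift out of alignment).
    captions.extend(example["captions"][:captions_per_image])
    image_paths.extend([example["image_path"]] * captions_per_image)

# truncation=True bounds every caption at max_length instead of producing ragged
# sequences that cannot be packed into a fixed-size numpy array.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
inputs = tokenizer(captions, max_length=128, padding="max_length", truncation=True, return_tensors="np")
print(len(image_paths), inputs["input_ids"].shape)  # 4 (4, 128)
```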