Commit 2988bf7 · 1 Parent(s): 1cff620
kaushalya committed

Add the cleaned dataset

data/train_dataset.json CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b6f8f9ecea3f4c6f8196f194510159fccde43ee7f2192b259a11d6bc9ad684cb
-size 13426560
+oid sha256:898792e7bb3f1b4b390d35a6e2bad326a1fd5db44e169dd33b02bce5f1d6a4dc
+size 14451628
data/valid_dataset.json CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7dbb940f0dee7cb4a85959dc6018aafc824a988b46e3ae8ca2fea6500251ee0a
-size 4132661
+oid sha256:5de7c2611565819c960f93f007a0015355686fc0c76ab496729474981c441a3d
+size 1834366
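
Both pointer files now reference the cleaned splits (train grows to ~14.5 MB, validation shrinks to ~1.8 MB). A minimal sanity check after git lfs pull, assuming each file is a single JSON array of records with "captions" and "image_path" keys as used by ImageTextDataset below; if the loader actually expects JSON Lines, swap json.load for a per-line json.loads:

    import json

    # Hypothetical quick check of the cleaned splits; the paths come from the diff above.
    for split in ("data/train_dataset.json", "data/valid_dataset.json"):
        with open(split) as f:
            examples = json.load(f)  # or: [json.loads(line) for line in f] for JSON Lines
        print(split, len(examples), "examples")
        print("  sample keys:", sorted(examples[0].keys()))  # expect: captions, image_path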
run_medclip.sh CHANGED
@@ -1,15 +1,15 @@
 python src/hybrid_clip/run_hybrid_clip.py \
-    --output_dir ./snapshots/final \
+    --output_dir ./snapshots/vision_augmented \
     --text_model_name_or_path="roberta-base" \
     --vision_model_name_or_path="openai/clip-vit-base-patch32" \
     --tokenizer_name="roberta-base" \
     --train_file="data/train_dataset.json" \
     --validation_file="data/valid_dataset.json" \
     --do_train --do_eval \
-    --num_train_epochs="40" --max_seq_length 96 \
+    --num_train_epochs="40" --max_seq_length 128 \
     --per_device_train_batch_size="64" \
     --per_device_eval_batch_size="64" \
-    --learning_rate="5e-4" --warmup_steps="0" --weight_decay 0.1 \
+    --learning_rate="5e-5" --warmup_steps="0" --weight_decay 0.1 \
     --overwrite_output_dir \
     --preprocessing_num_workers 32 \
     # --push_to_hub
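
The rerun writes checkpoints to ./snapshots/vision_augmented, raises max_seq_length from 96 to 128, and lowers the learning rate from 5e-4 to 5e-5, keeping warmup at 0 and weight decay at 0.1. A rough sketch of the optimizer these flags imply, assuming run_hybrid_clip.py builds the usual optax AdamW with linear warmup and decay as in the upstream hybrid_clip example (steps_per_epoch below is a made-up number):

    import optax

    num_train_epochs = 40     # --num_train_epochs="40"
    steps_per_epoch = 1_000   # hypothetical; depends on dataset size and batch size
    total_steps = num_train_epochs * steps_per_epoch
    warmup_steps = 0          # --warmup_steps="0"

    # Linear warmup to 5e-5 (a no-op here since warmup_steps == 0), then linear decay to 0.
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=5e-5, transition_steps=max(warmup_steps, 1))
    decay_fn = optax.linear_schedule(init_value=5e-5, end_value=0.0, transition_steps=total_steps - warmup_steps)
    schedule = optax.join_schedules([warmup_fn, decay_fn], boundaries=[warmup_steps])

    optimizer = optax.adamw(learning_rate=schedule, weight_decay=0.1)  # --weight_decay 0.1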
src/hybrid_clip/run_hybrid_clip.py CHANGED
@@ -37,6 +37,7 @@ from torchvision.datasets import VisionDataset
 from torchvision.io import ImageReadMode, read_image
 from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
 from torchvision.transforms.functional import InterpolationMode
+from torchvision.transforms.transforms import GaussianBlur, RandomAutocontrast, RandomHorizontalFlip
 from tqdm import tqdm
 
 import jax
@@ -178,6 +179,9 @@ class Transform(torch.nn.Module):
         self.transforms = torch.nn.Sequential(
             Resize([image_size], interpolation=InterpolationMode.BICUBIC),
             CenterCrop(image_size),
+            GaussianBlur(3, sigma=(0.05, 0.2)),
+            RandomAutocontrast(p=0.5),
+            RandomHorizontalFlip(p=0.5),
             ConvertImageDtype(torch.float),
             Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
         )
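
The three new transforms run before ConvertImageDtype, i.e. on the uint8 tensors produced by read_image, so every training image gets a mild blur plus random autocontrast and horizontal flips. A small sketch of the augmented pipeline in isolation (the 224 image size is just an example; the script takes it from the CLIP vision config):

    import torch
    from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
    from torchvision.transforms.functional import InterpolationMode
    from torchvision.transforms.transforms import GaussianBlur, RandomAutocontrast, RandomHorizontalFlip

    image_size = 224  # example value
    transforms = torch.nn.Sequential(
        Resize([image_size], interpolation=InterpolationMode.BICUBIC),
        CenterCrop(image_size),
        GaussianBlur(3, sigma=(0.05, 0.2)),   # mild blur, kernel size 3
        RandomAutocontrast(p=0.5),            # applied to half the images
        RandomHorizontalFlip(p=0.5),
        ConvertImageDtype(torch.float),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    )

    dummy = torch.randint(0, 256, (3, 512, 512), dtype=torch.uint8)  # stand-in for read_image output
    out = transforms(dummy)
    print(out.shape, out.dtype)  # torch.Size([3, 224, 224]) torch.float32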
@@ -224,10 +228,10 @@ class ImageTextDataset(VisionDataset):
         self.image_paths = []
 
         for example in examples:
-            self.captions.extend([example["captions"]])
-            self.image_paths.extend([example["image_path"]])
-            # self.captions.extend(example["captions"][:captions_per_image])
-            # self.image_paths.extend([example["image_path"]] * captions_per_image)
+            # self.captions.extend(example["captions"])
+            # self.image_paths.append(example["image_path"])
+            self.captions.extend(example["captions"][:captions_per_image])
+            self.image_paths.extend([example["image_path"]] * captions_per_image)
 
     def _load_image(self, idx: int):
         path = self.image_paths[idx]
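
The rewritten loop restores the captions_per_image expansion: the first N captions of each record are kept and the image path is repeated N times, so self.captions[i] always pairs with self.image_paths[i]. A tiny illustration with made-up records (like the diff, it assumes every record carries at least captions_per_image captions):

    # Hypothetical records shaped like the dataset entries ("captions" is a list per image).
    examples = [
        {"image_path": "images/roco_0001.jpg", "captions": ["Chest X-ray, PA view", "Normal cardiac silhouette"]},
        {"image_path": "images/roco_0002.jpg", "captions": ["Axial CT of the abdomen", "No free fluid seen"]},
    ]
    captions_per_image = 2  # example value

    captions, image_paths = [], []
    for example in examples:
        captions.extend(example["captions"][:captions_per_image])
        image_paths.extend([example["image_path"]] * captions_per_image)

    print(list(zip(image_paths, captions)))  # four aligned (path, caption) pairs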
@@ -374,8 +378,10 @@ def main():
     # Use collate function to tokenizer the text and convert the processed images to numpy
     def collate_fn(examples):
         pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
+        # pixel_values = torch.stack([example[0] for example in examples]).numpy()
         captions = [example[1] for example in examples]
-        inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", return_tensors="np")
+        inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", return_tensors="np",
+                           truncation=True)
 
         batch = {
             "pixel_values": pixel_values,
src/hybrid_clip/utils/roco_dataset.ipynb CHANGED
The diff for this file is too large to render. See raw diff
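
Finally, a standalone sketch of the tokenizer call introduced in collate_fn above: with max_seq_length now at 128, the added truncation=True keeps over-long captions from breaking the fixed-length padding. It assumes the roberta-base tokenizer named in run_medclip.sh:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("roberta-base")

    captions = [
        "Chest X-ray showing clear lung fields.",  # short caption, padded up to 128 tokens
        "a very long radiology report " * 40,      # well over 128 tokens, truncated
    ]
    inputs = tokenizer(captions, max_length=128, padding="max_length", return_tensors="np",
                       truncation=True)

    print(inputs["input_ids"].shape)  # (2, 128) regardless of caption length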