edugp committed on
Commit
8e2b754
·
1 Parent(s): 30ae6fe

Add training scripts and initial model trained on 1% of the data.

Browse files
.gitattributes CHANGED
@@ -14,3 +14,4 @@
14
  *.pb filter=lfs diff=lfs merge=lfs -text
15
  *.pt filter=lfs diff=lfs merge=lfs -text
16
  *.pth filter=lfs diff=lfs merge=lfs -text
 
 
14
  *.pb filter=lfs diff=lfs merge=lfs -text
15
  *.pt filter=lfs diff=lfs merge=lfs -text
16
  *.pth filter=lfs diff=lfs merge=lfs -text
17
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
clip_spanish_1_percent/config.json ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "HybridCLIP"
4
+ ],
5
+ "initializer_factor": 1.0,
6
+ "model_type": "hybrid-clip",
7
+ "projection_dim": 512,
8
+ "seed": 42,
9
+ "text_config": {
10
+ "_name_or_path": "dccuchile/bert-base-spanish-wwm-cased",
11
+ "add_cross_attention": false,
12
+ "architectures": [
13
+ "BertForMaskedLM"
14
+ ],
15
+ "attention_probs_dropout_prob": 0.1,
16
+ "bad_words_ids": null,
17
+ "bos_token_id": null,
18
+ "chunk_size_feed_forward": 0,
19
+ "decoder_start_token_id": null,
20
+ "diversity_penalty": 0.0,
21
+ "do_sample": false,
22
+ "early_stopping": false,
23
+ "encoder_no_repeat_ngram_size": 0,
24
+ "eos_token_id": null,
25
+ "finetuning_task": null,
26
+ "forced_bos_token_id": null,
27
+ "forced_eos_token_id": null,
28
+ "gradient_checkpointing": false,
29
+ "hidden_act": "gelu",
30
+ "hidden_dropout_prob": 0.1,
31
+ "hidden_size": 768,
32
+ "id2label": {
33
+ "0": "LABEL_0",
34
+ "1": "LABEL_1"
35
+ },
36
+ "initializer_range": 0.02,
37
+ "intermediate_size": 3072,
38
+ "is_decoder": false,
39
+ "is_encoder_decoder": false,
40
+ "label2id": {
41
+ "LABEL_0": 0,
42
+ "LABEL_1": 1
43
+ },
44
+ "layer_norm_eps": 1e-12,
45
+ "length_penalty": 1.0,
46
+ "max_length": 20,
47
+ "max_position_embeddings": 512,
48
+ "min_length": 0,
49
+ "model_type": "bert",
50
+ "no_repeat_ngram_size": 0,
51
+ "num_attention_heads": 12,
52
+ "num_beam_groups": 1,
53
+ "num_beams": 1,
54
+ "num_hidden_layers": 12,
55
+ "num_return_sequences": 1,
56
+ "output_attentions": false,
57
+ "output_hidden_states": false,
58
+ "output_past": true,
59
+ "output_scores": false,
60
+ "pad_token_id": 1,
61
+ "position_embedding_type": "absolute",
62
+ "prefix": null,
63
+ "problem_type": null,
64
+ "pruned_heads": {},
65
+ "remove_invalid_values": false,
66
+ "repetition_penalty": 1.0,
67
+ "return_dict": true,
68
+ "return_dict_in_generate": false,
69
+ "sep_token_id": null,
70
+ "task_specific_params": null,
71
+ "temperature": 1.0,
72
+ "tie_encoder_decoder": false,
73
+ "tie_word_embeddings": true,
74
+ "tokenizer_class": null,
75
+ "top_k": 50,
76
+ "top_p": 1.0,
77
+ "torch_dtype": null,
78
+ "torchscript": false,
79
+ "transformers_version": "4.9.0.dev0",
80
+ "type_vocab_size": 2,
81
+ "use_bfloat16": false,
82
+ "use_cache": true,
83
+ "vocab_size": 31002
84
+ },
85
+ "transformers_version": null,
86
+ "vision_config": {
87
+ "_name_or_path": "",
88
+ "add_cross_attention": false,
89
+ "architectures": null,
90
+ "attention_dropout": 0.0,
91
+ "bad_words_ids": null,
92
+ "bos_token_id": null,
93
+ "chunk_size_feed_forward": 0,
94
+ "decoder_start_token_id": null,
95
+ "diversity_penalty": 0.0,
96
+ "do_sample": false,
97
+ "dropout": 0.0,
98
+ "early_stopping": false,
99
+ "encoder_no_repeat_ngram_size": 0,
100
+ "eos_token_id": null,
101
+ "finetuning_task": null,
102
+ "forced_bos_token_id": null,
103
+ "forced_eos_token_id": null,
104
+ "gradient_checkpointing": false,
105
+ "hidden_act": "quick_gelu",
106
+ "hidden_size": 768,
107
+ "id2label": {
108
+ "0": "LABEL_0",
109
+ "1": "LABEL_1"
110
+ },
111
+ "image_size": 224,
112
+ "initializer_factor": 1.0,
113
+ "initializer_range": 0.02,
114
+ "intermediate_size": 3072,
115
+ "is_decoder": false,
116
+ "is_encoder_decoder": false,
117
+ "label2id": {
118
+ "LABEL_0": 0,
119
+ "LABEL_1": 1
120
+ },
121
+ "layer_norm_eps": 1e-05,
122
+ "length_penalty": 1.0,
123
+ "max_length": 20,
124
+ "min_length": 0,
125
+ "model_type": "clip_vision_model",
126
+ "no_repeat_ngram_size": 0,
127
+ "num_attention_heads": 12,
128
+ "num_beam_groups": 1,
129
+ "num_beams": 1,
130
+ "num_hidden_layers": 12,
131
+ "num_return_sequences": 1,
132
+ "output_attentions": false,
133
+ "output_hidden_states": false,
134
+ "output_scores": false,
135
+ "pad_token_id": null,
136
+ "patch_size": 32,
137
+ "prefix": null,
138
+ "problem_type": null,
139
+ "pruned_heads": {},
140
+ "remove_invalid_values": false,
141
+ "repetition_penalty": 1.0,
142
+ "return_dict": true,
143
+ "return_dict_in_generate": false,
144
+ "sep_token_id": null,
145
+ "task_specific_params": null,
146
+ "temperature": 1.0,
147
+ "tie_encoder_decoder": false,
148
+ "tie_word_embeddings": true,
149
+ "tokenizer_class": null,
150
+ "top_k": 50,
151
+ "top_p": 1.0,
152
+ "torch_dtype": null,
153
+ "torchscript": false,
154
+ "transformers_version": "4.9.0.dev0",
155
+ "use_bfloat16": false
156
+ }
157
+ }
clip_spanish_1_percent/flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29e4478aa3195ba626a7051a3d2a8d17bb540b4e68d8d75cca2d549104e586c2
3
+ size 792387416
configuration_hybrid_clip.py ADDED
@@ -0,0 +1 @@
 
 
1
+ /home/eduardogonzalezponferrada/transformers/examples/research_projects/jax-projects/hybrid_clip/configuration_hybrid_clip.py
discard_incorrect_files.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ import torch
5
+ from torchvision.io import ImageReadMode, read_image
6
+
7
+ # SUPPORTED_EXTENSIONS = {'PNG', 'JPG', 'png', 'JPEG', 'jpg', 'jpeg'}
8
+
9
# Filter every prepared split down to the examples whose image file can be
# decoded, and write the survivors to a sibling "<split>_dataset_filtered.json".
dataset_dir = f"/home/{os.environ['USER']}/data/wit/prepared_dataset"

for split in ["train", "valid", "test"]:
    with open(f"{dataset_dir}/{split}_dataset.json") as f:
        # The file holds one JSON object per line. Blank lines (for example a
        # trailing newline) are skipped so json.loads("") cannot abort the run.
        examples = [json.loads(line) for line in f if line.strip()]

    supported_examples = []
    for example in examples:
        try:
            # Decoding the file is the whole check; the pixels are discarded.
            read_image(example["image_path"], mode=ImageReadMode.RGB)
        except Exception as e:
            # Deliberately broad: whatever the decoder raises, the example is
            # reported and left out instead of stopping the whole filter pass.
            print(f"Excluding file: {example['image_path']} due to error: {e}")
        else:
            supported_examples.append(json.dumps(example, ensure_ascii=False))

    print(f"Total {split} examples: {len(supported_examples)}")
    with open(f"{dataset_dir}/{split}_dataset_filtered.json", "w") as f:
        f.write("\n".join(supported_examples))
modeling_hybrid_clip.py ADDED
@@ -0,0 +1 @@
 
 
1
+ /home/eduardogonzalezponferrada/transformers/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py
prepare_wit.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ import time
6
+ import urllib.request
7
+ import urllib.error
8
+
9
+ import pandas as pd
10
+ from tqdm import tqdm
11
+
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
def prepare_wit(
    tsv: str,
    language: str,
    output_dir: str,
    seed: int,
    train_proportion: float,
    valid_proportion: float,
    language_col: str = "language",
    caption_col: str = "caption_reference_description",
    url_col: str = "image_url",
    pause: float = 1.0,
    retries: int = 5,
):
    """Download the captioned images of one language and write dataset splits.

    Rows of the tab-separated file ``tsv`` are kept when ``language_col``
    equals ``language`` and ``caption_col`` is not null. The kept rows are
    shuffled with ``seed``, each image is downloaded into ``output_dir``, and
    one JSON line ``{"image_path": ..., "captions": [caption]}`` is recorded
    per successful download. The recorded lines are then written to
    ``train_dataset.json``, ``valid_dataset.json`` and ``test_dataset.json``
    inside ``output_dir``; the test split receives whatever remains after the
    train and valid proportions.

    The split files are written from a ``finally`` clause, so whatever was
    downloaded before an error or interruption is still saved.

    Args:
        tsv: Path of the tab-separated input file.
        language: Value of ``language_col`` to keep.
        output_dir: Directory for images and split files (created if missing).
        seed: Random state used for shuffling the rows.
        train_proportion: Fraction of the recorded lines used for training.
        valid_proportion: Fraction of the recorded lines used for validation.
        language_col: Name of the column holding the language code.
        caption_col: Name of the column holding the caption.
        url_col: Name of the column holding the image URL.
        pause: Seconds to sleep after every failed download attempt.
        retries: Maximum number of download attempts per image.

    Raises:
        ValueError: If an image still answers HTTP 429 on its last attempt.
    """
    os.makedirs(output_dir, exist_ok=True)
    df = pd.read_csv(tsv, sep="\t", engine="python")
    # Honor the configurable column names instead of hard-coded ones.
    df = df[(df[language_col] == language) & (~df[caption_col].isnull())]
    # Shuffle
    df = df.sample(frac=1.0, random_state=seed)
    lines = []
    try:
        # A single progress bar, advanced once per row.
        with tqdm(total=len(df)) as pbar:
            for _, row in df.iterrows():
                url = row[url_col]
                caption = row[caption_col]
                # Trim image file names so that they are no longer than 100 characters
                # NOTE(review): two URLs sharing their last 100 characters would
                # map to the same file - confirm this cannot happen in the data.
                image_filename = url.split("/")[-1][-100:]
                image_path = f"{output_dir}/{image_filename}"
                for retry in range(retries):
                    try:
                        # Download file
                        urllib.request.urlretrieve(url, image_path)
                    except urllib.error.HTTPError as e:
                        time.sleep(pause)
                        # range(retries) stops at retries - 1, so that is the
                        # index of the last attempt.
                        if retry == retries - 1:
                            if e.code == 429:
                                # Still rate limited after every retry: abort;
                                # the finally clause saves what was collected.
                                raise ValueError("Rate limit achieved:", e) from e
                            # Any other persistent HTTP error (e.g. a missing
                            # image) only costs this one example.
                            logger.warning("Skipping %s after %d attempts: %s", url, retries, e)
                    else:
                        lines.append(
                            json.dumps({"image_path": image_path, "captions": [caption]}, ensure_ascii=False)
                        )
                        break
                pbar.update(1)
    # Save existing dataset, even upon failure
    finally:
        total_lines = len(lines)
        train_end = int(total_lines * train_proportion)
        valid_end = int(total_lines * (train_proportion + valid_proportion))
        splits = {
            "train": lines[:train_end],
            "valid": lines[train_end:valid_end],
            "test": lines[valid_end:],
        }
        for split_name, split_lines in splits.items():
            with open(f"{output_dir}/{split_name}_dataset.json", "w") as f:
                f.write("\n".join(split_lines))
56
+
57
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the preparation.
    parser = argparse.ArgumentParser(description="Download and prepare the WIT dataset")
    parser.add_argument("--tsv", type=str, default=f"/home/{os.environ['USER']}/data/wit/wit_v1.train.all-1percent_sample.tsv")
    parser.add_argument("--language", type=str, default="es")
    parser.add_argument("--output_dir", type=str, default=f"/home/{os.environ['USER']}/data/wit/prepared_dataset")
    parser.add_argument("--random_seed", type=int, default=0)
    parser.add_argument("--train_proportion", type=float, default=0.8)
    parser.add_argument("--valid_proportion", type=float, default=0.1)

    args = parser.parse_args()
    # Reject bad proportions through argparse rather than `assert`: assertions
    # are stripped under `python -O`, which would let an invalid split through.
    if args.train_proportion + args.valid_proportion >= 1.0:
        parser.error("The sum of train_proportion and valid_proportion has to be < 1.0")
    prepare_wit(args.tsv, args.language, args.output_dir, args.random_seed, args.train_proportion, args.valid_proportion)
run-clip.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Launch hybrid CLIP training: Spanish BERT text encoder + CLIP ViT-B/32
# vision encoder, trained on the filtered WIT train/valid JSON files.

# Read the Hugging Face token. The path is quoted so a home directory that
# contains spaces does not get word-split.
# NOTE(review): HUB_TOKEN is neither exported nor passed to the command below
# as written - confirm whether anything still needs it.
HUB_TOKEN=$(cat "$HOME/.huggingface/token")

python run_hybrid_clip.py \
    --output_dir "./output_dir" \
    --text_model_name_or_path="dccuchile/bert-base-spanish-wwm-cased" \
    --vision_model_name_or_path="openai/clip-vit-base-patch32" \
    --tokenizer_name="dccuchile/bert-base-spanish-wwm-cased" \
    --train_file="/home/${USER}/data/wit/prepared_dataset/train_dataset_filtered.json" \
    --validation_file="/home/${USER}/data/wit/prepared_dataset/valid_dataset_filtered.json" \
    --do_train --do_eval \
    --num_train_epochs="40" \
    --max_seq_length 96 \
    --per_device_train_batch_size="64" \
    --per_device_eval_batch_size="64" \
    --learning_rate="5e-5" --warmup_steps="0" --weight_decay 0.1 \
    --overwrite_output_dir \
    --preprocessing_num_workers 32
    # Disabled option, kept for reference (it would have to be appended to the
    # command above with a line continuation to take effect):
    #--push_to_hub
18
+
run_hybrid_clip.py ADDED
@@ -0,0 +1 @@
 
 
1
+ /home/eduardogonzalezponferrada/transformers/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py
test_on_image.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import torch
3
+ from torchvision.io import ImageReadMode, read_image
4
+ from transformers import AutoTokenizer
5
+
6
+ from modeling_hybrid_clip import FlaxHybridCLIP
7
+ from run_hybrid_clip import Transform
8
+
9
# Load the hybrid CLIP checkpoint identified by "clip_spanish_1_percent" and
# the tokenizer of the Spanish BERT model once, at import time; the helper
# functions below read both as module-level globals.
model = FlaxHybridCLIP.from_pretrained("clip_spanish_1_percent")
tokenizer = AutoTokenizer.from_pretrained("dccuchile/bert-base-spanish-wwm-cased")
11
+
12
def prepare_image(image_path):
    """Load one image and return it as a single-element NumPy batch.

    The file is read as RGB, passed through a TorchScript-compiled
    ``Transform`` sized from the model's vision config, wrapped into a batch
    of one, and has axis 1 moved to the last position before conversion to
    NumPy (NCHW -> NHWC, assuming the transform outputs channels-first).
    """
    raw_image = read_image(image_path, mode=ImageReadMode.RGB)
    # The transform is built and scripted on every call, sized for this model.
    transform = torch.jit.script(Transform(model.config.vision_config.image_size))
    batch = torch.stack([transform(raw_image)])
    return batch.permute(0, 2, 3, 1).numpy()
19
+
20
def prepare_text(text):
    """Tokenize ``text`` with the module-level tokenizer, requesting NumPy tensors."""
    encoded = tokenizer(text, return_tensors="np")
    return encoded
22
+
23
def run_inference(image_path, text):
    """Score how well ``text`` matches the image stored at ``image_path``.

    The image and the text are preprocessed, fed through the module-level
    model in inference mode, and the sigmoid of ``logits_per_image`` is
    returned.
    """
    pixels = prepare_image(image_path)
    encoded = prepare_text(text)
    outputs = model(
        encoded["input_ids"],
        pixels,
        attention_mask=encoded["attention_mask"],
        token_type_ids=encoded["token_type_ids"],
        train=False,
        return_dict=True,
    )
    # NOTE(review): the sigmoid squashes the raw image-text logit into (0, 1);
    # confirm this is the intended score rather than a softmax over candidates.
    return jax.nn.sigmoid(outputs["logits_per_image"])
30
+
31
# Smoke test: score one hard-coded image against one Spanish caption.
# The percent-escapes (%28, %29) are part of the file name as written here.
# NOTE(review): the path is specific to the author's machine - adjust it
# before running elsewhere.
image_path = "/home/eduardogonzalezponferrada/data/wit/full_dataset/Casa_de_Cultura_%284%29.JPG"
text = "Patio interior de un edificio"

print(run_inference(image_path, text))