edugp commited on
Commit
04899df
·
1 Parent(s): 5f52f84

Switch to BERTIN model for training script and testing on image

Browse files
Files changed (2) hide show
  1. run-clip.sh +2 -2
  2. test_on_image.py +3 -3
run-clip.sh CHANGED
@@ -1,8 +1,8 @@
1
  python run_hybrid_clip.py \
2
  --output_dir "./output_141230_training_examples" \
3
- --text_model_name_or_path="dccuchile/bert-base-spanish-wwm-cased" \
4
  --vision_model_name_or_path="openai/clip-vit-base-patch32" \
5
- --tokenizer_name="dccuchile/bert-base-spanish-wwm-cased" \
6
  --train_file="/home/${USER}/data/wit_scale_converted/train_dataset_scale_converted_98_1_1_split.json" \
7
  --validation_file="/home/${USER}/data/wit_scale_converted/valid_dataset_scale_converted_98_1_1_split.json" \
8
  --do_train \
 
1
  python run_hybrid_clip.py \
2
  --output_dir "./output_141230_training_examples" \
3
+ --text_model_name_or_path="bertin-project/bertin-roberta-base-spanish" \
4
  --vision_model_name_or_path="openai/clip-vit-base-patch32" \
5
+ --tokenizer_name="bertin-project/bertin-roberta-base-spanish" \
6
  --train_file="/home/${USER}/data/wit_scale_converted/train_dataset_scale_converted_98_1_1_split.json" \
7
  --validation_file="/home/${USER}/data/wit_scale_converted/valid_dataset_scale_converted_98_1_1_split.json" \
8
  --do_train \
test_on_image.py CHANGED
@@ -23,15 +23,15 @@ def prepare_text(text, tokenizer):
23
  def run_inference(image_path, text, model, tokenizer):
24
  pixel_values = prepare_image(image_path, model)
25
  input_text = prepare_text(text, tokenizer)
26
- model_output = model(input_text["input_ids"], pixel_values, attention_mask=input_text["attention_mask"], token_type_ids=input_text["token_type_ids"], train=False, return_dict=True)
27
  logits = model_output["logits_per_image"]
28
  score = jax.nn.sigmoid(logits)[0][0]
29
  return score
30
 
31
 
32
  if __name__ == "__main__":
33
- model = FlaxHybridCLIP.from_pretrained("clip_spanish_141230_samples")
34
- tokenizer = AutoTokenizer.from_pretrained("dccuchile/bert-base-spanish-wwm-cased")
35
 
36
  image_path = f"/home/{os.environ['USER']}/data/wit_scale_converted/Santuar.jpg"
37
  text = "Fachada del Santuario"
 
23
  def run_inference(image_path, text, model, tokenizer):
24
  pixel_values = prepare_image(image_path, model)
25
  input_text = prepare_text(text, tokenizer)
26
+ model_output = model(input_text["input_ids"], pixel_values, attention_mask=input_text["attention_mask"], train=False, return_dict=True)
27
  logits = model_output["logits_per_image"]
28
  score = jax.nn.sigmoid(logits)[0][0]
29
  return score
30
 
31
 
32
  if __name__ == "__main__":
33
+ model = FlaxHybridCLIP.from_pretrained("./")
34
+ tokenizer = AutoTokenizer.from_pretrained("bertin-project/bertin-roberta-base-spanish")
35
 
36
  image_path = f"/home/{os.environ['USER']}/data/wit_scale_converted/Santuar.jpg"
37
  text = "Fachada del Santuario"