File size: 4,744 Bytes
8e2b754 4463ade 8e2b754 98c2b8e 8e2b754 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 2daf3c7 4463ade 8e2b754 98c2b8e 8e2b754 2daf3c7 4463ade 8e2b754 2daf3c7 8e2b754 98c2b8e 8e2b754 2daf3c7 8e2b754 4463ade 98c2b8e 8e2b754 4463ade 2daf3c7 98c2b8e 4463ade 2daf3c7 8e2b754 98c2b8e 8e2b754 4463ade 8e2b754 4463ade 8e2b754 4463ade 8e2b754 98c2b8e 8e2b754 4463ade |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import argparse
import json
import logging
import os
import time
import urllib.error
import urllib.request
from typing import List
import pandas as pd
from tqdm import tqdm
# Configure root logging at import time: INFO level, timestamped records that
# include the level and logger name.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
# Logger for this module, named after the module itself.
logger: logging.Logger = logging.getLogger(name=__name__)
def split_and_save_datasets(
    lines: List[str], output_dir: str, train_proportion: float, valid_proportion: float
) -> None:
    """Split ``lines`` into train/valid/test portions and write each to disk.

    The first ``int(len(lines) * train_proportion)`` lines go to
    ``train_dataset.json``, the lines up to
    ``int(len(lines) * (train_proportion + valid_proportion))`` go to
    ``valid_dataset.json``, and the remainder goes to ``test_dataset.json``.
    Each file contains its lines joined by newlines (no trailing newline).
    Existing files are overwritten.

    Args:
        lines: Lines to split, in order. In this script each line is one JSON
            document, so the output files are effectively JSON Lines.
        output_dir: Existing directory that receives the three files.
        train_proportion: Fraction of lines assigned to the train split.
        valid_proportion: Fraction of lines assigned to the validation split.
    """
    total_lines = len(lines)
    # Compute the cut points once so the three slices tile ``lines`` exactly.
    train_end = int(total_lines * train_proportion)
    valid_end = int(total_lines * (train_proportion + valid_proportion))
    splits = {
        "train": lines[:train_end],
        "valid": lines[train_end:valid_end],
        "test": lines[valid_end:],
    }
    for split_name, split_lines in splits.items():
        path = os.path.join(output_dir, f"{split_name}_dataset.json")
        # Write UTF-8 explicitly: lines may hold non-ASCII text (captions are
        # dumped with ensure_ascii=False), which the platform default encoding
        # may not be able to represent.
        with open(path, "w", encoding="utf-8") as f:
            f.write("\n".join(split_lines))
def _image_filename(url: str) -> str:
    """Return the local file name used to store the image at ``url``.

    This is the last ``/``-separated segment of the URL, trimmed to its final
    100 characters so file names never exceed 100 characters (keeping the
    tail preserves the file extension).
    """
    return url.split("/")[-1][-100:]


def _download_with_retries(
    url: str, image_path: str, retries: int, pause: float
) -> bool:
    """Download ``url`` to ``image_path``, retrying transient failures.

    Makes at most ``retries`` attempts and sleeps ``pause * 10`` seconds
    between consecutive attempts (not after the last one). HTTP 404/410 mean
    the image is gone, so they are not retried. Other ``URLError``s (which
    include ``HTTPError`` and ``ContentTooShortError``) and ``ConnectionError``
    are retried rather than aborting the whole run.

    NOTE(review): when ``urlretrieve`` fails mid-download (e.g.
    ``ContentTooShortError``) a partial file may remain at ``image_path``.

    Returns:
        True if the file was downloaded, False if no attempt succeeded.
    """
    for attempt in range(1, retries + 1):
        try:
            urllib.request.urlretrieve(url, image_path)
            return True
        except urllib.error.HTTPError as err:
            if err.code in (404, 410):
                logger.debug("Not retrying %s: HTTP %d", url, err.code)
                return False
            error = err
        except (urllib.error.URLError, ConnectionError) as err:
            error = err
        logger.debug("Attempt %d/%d for %s failed: %s", attempt, retries, url, error)
        if attempt < retries:
            time.sleep(pause * 10)
    return False


def prepare_wit(
    tsv: str,
    language: str,
    output_dir: str,
    seed: int,
    train_proportion: float,
    valid_proportion: float,
    backup_period: int,
    language_col: str = "language",
    caption_col: str = "caption_reference_description",
    url_col: str = "image_url",
    pause: float = 0.875,
    retries: int = 10,
) -> None:
    """Download WIT images for one language and write train/valid/test splits.

    Reads the tab-separated file ``tsv`` and keeps the rows whose
    ``language_col`` equals ``language`` and that have a non-null caption and
    URL. Rows whose image file already exists in ``output_dir`` are skipped.
    The remaining rows are shuffled with ``seed``, and each image is downloaded
    into ``output_dir``. Every successful download yields a JSON line
    ``{"image_path": ..., "captions": [caption]}``. The collected lines are
    saved with ``split_and_save_datasets`` every ``backup_period`` successful
    downloads (disabled when ``backup_period <= 0``) and once more at the end,
    including when an exception interrupts the run.

    Args:
        tsv: Path to the WIT TSV file.
        language: Value of ``language_col`` to keep.
        output_dir: Directory for images and split files (created if missing).
        seed: Random state used to shuffle the rows.
        train_proportion: Fraction of lines for the train split.
        valid_proportion: Fraction of lines for the validation split.
        backup_period: Save a backup after this many successful downloads.
        language_col: Column holding the language code.
        caption_col: Column holding the caption text.
        url_col: Column holding the image URL.
        pause: Base pause in seconds; failed attempts wait ``pause * 10``.
        retries: Maximum number of download attempts per image.
    """
    os.makedirs(output_dir, exist_ok=True)
    logger.info("Loading dataset")
    df = pd.read_csv(tsv, sep="\t", engine="python")
    # Filter on the configurable column names. Null URLs are dropped too,
    # because file names are derived from the URL string.
    df = df[
        (df[language_col] == language)
        & df[caption_col].notnull()
        & df[url_col].notnull()
    ]
    # Skip images already present, e.g. from an interrupted earlier run.
    # NOTE(review): such images are not added to ``lines``, and the final save
    # overwrites the split files. A resumed run therefore only records the
    # newly downloaded images. Confirm whether that is intended.
    existing_files = set(os.listdir(output_dir))
    already_downloaded = df[url_col].map(_image_filename).isin(existing_files)
    df = df[~already_downloaded]
    # Shuffle
    df = df.sample(frac=1.0, random_state=seed)
    logger.info("Trying to download %d files", df.shape[0])
    lines: List[str] = []
    try:
        with tqdm(total=len(df)) as pbar:
            for url, caption in zip(df[url_col], df[caption_col]):
                image_filename = _image_filename(url)
                image_path = f"{output_dir}/{image_filename}"
                if _download_with_retries(url, image_path, retries, pause):
                    lines.append(
                        json.dumps(
                            {"image_path": image_path, "captions": [caption]},
                            ensure_ascii=False,
                        )
                    )
                    # Back up only right after a successful download, so a run
                    # of failures doesn't re-save the same lines on every row.
                    if backup_period > 0 and len(lines) % backup_period == 0:
                        logger.info(
                            "Saving dataset backup: Number of lines %d", len(lines)
                        )
                        split_and_save_datasets(
                            lines, output_dir, train_proportion, valid_proportion
                        )
                else:
                    logger.info("Skipping %s", image_filename)
                pbar.update(1)
    # Save existing dataset, even upon failure
    finally:
        split_and_save_datasets(lines, output_dir, train_proportion, valid_proportion)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download and prepare the WIT dataset")
    # Default paths live under the user's home directory. expanduser("~") is
    # used instead of "/home/" + os.environ["USER"]: that form raises KeyError
    # when USER is unset (common in containers and cron) and assumes a Linux
    # /home layout.
    default_wit_dir = os.path.join(os.path.expanduser("~"), "data", "wit")
    parser.add_argument(
        "--tsv",
        type=str,
        default=os.path.join(default_wit_dir, "wit_v1.train.all-1percent_sample.tsv"),
    )
    parser.add_argument("--language", type=str, default="es")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=os.path.join(default_wit_dir, "prepared_dataset"),
    )
    parser.add_argument("--random_seed", type=int, default=0)
    parser.add_argument("--train_proportion", type=float, default=0.8)
    parser.add_argument("--valid_proportion", type=float, default=0.1)
    parser.add_argument("--backup_period", type=int, default=1000)
    args = parser.parse_args()
    # Validate with parser.error (prints usage, exits with status 2) rather
    # than assert, which is stripped when Python runs with -O.
    if args.train_proportion < 0 or args.valid_proportion < 0:
        parser.error("train_proportion and valid_proportion must be non-negative")
    if args.train_proportion + args.valid_proportion >= 1.0:
        parser.error("The sum of train_proportion and valid_proportion has to be < 1.0")
    prepare_wit(
        args.tsv,
        args.language,
        args.output_dir,
        args.random_seed,
        args.train_proportion,
        args.valid_proportion,
        args.backup_period,
    )
|