File size: 2,830 Bytes
37b3db0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import tarfile
import io
import pandas as pd
import ast
from tqdm import tqdm
import argparse


def create_webdataset(csv_file: str, output_dir: str, parent_dataset_path: str,
                      tar_size: int = 1000) -> None:
    """Pack the images and captions listed in a CSV into numbered tar shards.

    Rows are read positionally from the CSV (pandas treats the first line as
    the header row by default):

    * column 0 -- a Python literal (e.g. a list) whose first element is the
      image path relative to ``parent_dataset_path``;
    * column 2 -- a Python literal sequence of caption strings.

    Column 1 (the label) is not read: the ``.cls`` member is not written to
    the shards, so a missing or malformed label no longer aborts the run.

    For every row two members are appended to the current shard, both named
    after the zero-based row index (six digits, counted across all shards):

    * ``NNNNNN.jpg`` -- the raw bytes of the image file (the member name
      always ends in ``.jpg``, whatever the source file's extension is);
    * ``NNNNNN.txt`` -- the UTF-8 encoded captions, each stripped of leading
      and trailing ``'.'`` and followed by ``"._chexpert_"``, concatenated.

    Args:
        csv_file: Path to the comma-separated input file.
        output_dir: Directory for the shards; created if it does not exist.
        parent_dataset_path: Directory that the image paths are relative to.
        tar_size: Number of rows stored per shard. Shards are named
            ``dataset-000001.tar``, ``dataset-000002.tar``, ...

    Raises:
        OSError: If an image file cannot be opened or a shard cannot be
            written.
        ValueError, SyntaxError: If a cell in column 0 or 2 is not a valid
            Python literal.
    """
    os.makedirs(output_dir, exist_ok=True)

    # pandas opens the path itself, so no separate file handle is needed.
    frame = pd.read_csv(csv_file, delimiter=',')

    tar = None
    tar_index = 0
    try:
        for file_count, row in enumerate(tqdm(frame.values)):
            # Roll over to a fresh shard every `tar_size` rows.
            if file_count % tar_size == 0:
                if tar is not None:
                    tar.close()
                tar_index += 1
                tar_path = os.path.join(output_dir, f"dataset-{tar_index:06d}.tar")
                tar = tarfile.open(tar_path, 'w')

            filename = ast.literal_eval(row[0])[0]
            all_caption = ast.literal_eval(row[2])
            caption = ''.join(
                single_caption.strip('.') + "._chexpert_"
                for single_caption in all_caption
            )

            # Read the image file
            image_path = os.path.join(parent_dataset_path, filename)
            with open(image_path, 'rb') as img_file:
                img_data = img_file.read()

            # TarInfo.size must be set explicitly when adding in-memory data.
            img_tarinfo = tarfile.TarInfo(name=f"{file_count:06d}.jpg")
            img_tarinfo.size = len(img_data)
            tar.addfile(img_tarinfo, io.BytesIO(img_data))

            caption_data = caption.encode('utf-8')
            caption_tarinfo = tarfile.TarInfo(name=f"{file_count:06d}.txt")
            caption_tarinfo.size = len(caption_data)
            tar.addfile(caption_tarinfo, io.BytesIO(caption_data))
    finally:
        # Close the current shard even when a row fails, so the members
        # already written are flushed and the file handle is released.
        if tar is not None:
            tar.close()


if __name__ == "__main__":
    # Command-line entry point: collect the four options and hand them to
    # create_webdataset().
    cli = argparse.ArgumentParser(description="Create a WebDataset from CSV")

    # The three path options are all mandatory strings; register them from a table.
    required_paths = (
        ('--csv_file', "Path to the CSV file"),
        ('--output_dir', "Directory to store the output tar files"),
        ('--parent_dataset_path', "Path to the parent dataset containing images"),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    # Shard size is optional and falls back to 1000.
    cli.add_argument('--tar_size', type=int, default=1000, help="Number of files per tar file")

    options = cli.parse_args()

    create_webdataset(
        options.csv_file,
        options.output_dir,
        options.parent_dataset_path,
        options.tar_size,
    )