muzairkhattak
first commit for the demo
37b3db0
import os
import tarfile
import io
import pandas as pd
import ast
from tqdm import tqdm
import argparse
def _add_bytes_member(tar, name, data):
    """Append ``data`` (bytes) to the open tar archive ``tar`` as a regular file named ``name``."""
    info = tarfile.TarInfo(name=name)
    info.size = len(data)
    tar.addfile(info, io.BytesIO(data))


def create_webdataset(csv_file, output_dir, parent_dataset_path, tar_size=1000, include_label=False):
    """Pack the images and captions listed in a CSV file into tar shards.

    The CSV is read with ``pandas.read_csv`` (comma-delimited, first row used
    as the header). The first three columns of every row must be strings that
    ``ast.literal_eval`` can parse:

    * column 0: a sequence whose first element is the image path, relative to
      ``parent_dataset_path``;
    * column 1: a sequence of labels (only read when ``include_label`` is True);
    * column 2: an iterable of caption strings.

    For the i-th row (0-based), the shard receives:

    * ``{i:06d}.jpg`` containing the raw bytes of the image file;
    * ``{i:06d}.txt`` containing every caption, each followed by the literal
      ``"._radimagenet_"``;
    * ``{i:06d}.cls`` containing the UTF-8 text of the first label, but only
      when ``include_label`` is True.

    Shards are named ``dataset-000001.tar``, ``dataset-000002.tar``, ... and
    each holds at most ``tar_size`` rows.

    Args:
        csv_file: Path to the CSV file.
        output_dir: Directory for the tar shards; created if it does not exist.
        parent_dataset_path: Directory the image paths in the CSV are relative to.
        tar_size: Maximum number of rows (samples) per shard. Must be positive.
        include_label: Also write the ``.cls`` member. Defaults to False, which
            matches the previous behaviour (that member was disabled).

    Raises:
        ValueError: If ``tar_size`` is not a positive number.
    """
    if tar_size <= 0:
        # Without this check, ``file_count % tar_size`` below would raise
        # ZeroDivisionError for 0 and never start a new shard for negatives.
        raise ValueError(f"tar_size must be a positive integer, got {tar_size!r}")

    os.makedirs(output_dir, exist_ok=True)
    # pandas opens and closes the file itself; no separate open() handle is needed.
    reader = pd.read_csv(csv_file, delimiter=',')

    tar = None
    try:
        for file_count, row in enumerate(tqdm(reader.values)):
            if file_count % tar_size == 0:
                # Start a new shard every ``tar_size`` rows; shard numbering starts at 1.
                if tar is not None:
                    tar.close()
                shard_index = file_count // tar_size + 1
                tar_path = os.path.join(output_dir, f"dataset-{shard_index:06d}.tar")
                tar = tarfile.open(tar_path, 'w')

            filename = ast.literal_eval(row[0])[0]
            all_caption = ast.literal_eval(row[2])
            # Every caption, including the last one, is followed by the separator.
            caption = "".join(single_caption + "._radimagenet_" for single_caption in all_caption)

            image_path = os.path.join(parent_dataset_path, filename)
            with open(image_path, 'rb') as img_file:
                img_data = img_file.read()
            # The member is always named .jpg, whatever the source file's real format.
            _add_bytes_member(tar, f"{file_count:06d}.jpg", img_data)

            if include_label:
                label = ast.literal_eval(row[1])
                _add_bytes_member(tar, f"{file_count:06d}.cls", str(label[0]).encode('utf-8'))

            _add_bytes_member(tar, f"{file_count:06d}.txt", caption.encode('utf-8'))
    finally:
        # Close the current shard even if a row fails, so the archive is not leaked.
        if tar is not None:
            tar.close()
if __name__ == "__main__":
    # Command-line entry point. Each option's destination name is identical to
    # a create_webdataset() parameter, so the parsed namespace is forwarded as
    # keyword arguments.
    cli = argparse.ArgumentParser(description="Create a WebDataset from CSV")
    cli.add_argument('--csv_file', type=str, required=True,
                     help="Path to the CSV file")
    cli.add_argument('--output_dir', type=str, required=True,
                     help="Directory to store the output tar files")
    cli.add_argument('--parent_dataset_path', type=str, required=True,
                     help="Path to the parent dataset containing images")
    cli.add_argument('--tar_size', type=int, default=1000,
                     help="Number of files per tar file")

    options = vars(cli.parse_args())
    create_webdataset(**options)