# frugal_cviz/src/load_data.py
# (Hugging Face scrape residue removed: author "sycod", commit 38dbb38
#  "first train ok, oversampling begun", ~7 kB)
"""Load dataset and save locally in selected format"""
from datasets import load_dataset
import logging
import os
import pandas as pd
import shutil
import subprocess
import yaml
# Logging configuration: raise the ROOT logger to INFO so messages from this
# script and from libraries are all visible (see all outputs, even DEBUG or INFO)
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Local config, loaded at import time — config.yaml must exist in the CWD
with open("config.yaml", "r") as f:
    cfg = yaml.safe_load(f)

REPO_ID = cfg["repo_id"]  # Hugging Face dataset repo id passed to load_dataset
SPLIT_SIZE = cfg["split_size"]  # test fraction/size for train_test_split
RDM_SEED = cfg["rdm_seed"]  # seed so the train/test split is reproducible
OUTPUT_DIR = cfg["data_root_dir"]  # root folder for all exported data
RAW_DATA_DIR = os.path.join(OUTPUT_DIR, cfg["raw_data_dir"])  # Ultralytics-format export
CLR_CACHE_SCRIPT = cfg["clr_hf_cache_script_abs_path"]  # bash script clearing the HF cache
DB_INFO_URI = os.path.join(OUTPUT_DIR, cfg["db_info_uri"])  # CSV with per-image metadata
# Save in Ultralytics format
# Save in Ultralytics format
def save_ultralytics_format(dataset_split, split, IMAGE_DIR, LABEL_DIR):
    """Save a dataset split into the Ultralytics YOLO directory layout.

    Images are written to ``<IMAGE_DIR>/<split>/<image_name>`` and the
    matching label files to ``<LABEL_DIR>/<split>/<stem>.txt``.

    Args:
        dataset_split: Iterable of examples (e.g. dataset["train"]); each
            example provides "image" (object with a ``.save(path)`` method,
            e.g. PIL.Image.Image), "image_name" (original file name) and
            "annotations" (label text written verbatim).
        split: "train", "test" or "val"
        IMAGE_DIR: Root directory holding the per-split image folders.
        LABEL_DIR: Root directory holding the per-split label folders.
    """
    image_split_dir = os.path.join(IMAGE_DIR, split)
    label_split_dir = os.path.join(LABEL_DIR, split)
    for example in dataset_split:
        # Save image to appropriate folder under its original name
        image = example["image"]
        image_name = example["image_name"]
        image.save(os.path.join(image_split_dir, image_name))
        # Label file shares the image stem. splitext handles any extension
        # (.jpg, .png, but also .jpeg etc.) and only touches the suffix,
        # unlike chained str.replace which could corrupt names containing
        # ".jpg"/".png" mid-string.
        label_name = os.path.splitext(image_name)[0] + ".txt"
        output_label_path = os.path.join(label_split_dir, label_name)
        with open(output_label_path, "w") as label_file:
            label_file.write(example["annotations"])
    logging.info(f"Dataset {split} split exported to Ultralytics format")
def create_df(ds, split_name, output_dir):
    """Build a per-image metadata DataFrame for one dataset split.

    Collects image geometry/format info, the split name, the on-disk URI
    of each image, and the acquisition metadata carried by the dataset.
    """
    geometry = [
        [img.size[0], img.size[1], img.format, img.mode] for img in ds["image"]
    ]
    df = pd.DataFrame(geometry, columns=["width", "height", "format", "mode"])
    df["name"] = ds["image_name"]
    df["split"] = split_name
    image_root = os.path.join(output_dir, "images", split_name)
    df["uri"] = [os.path.join(image_root, name) for name in df["name"]]
    df["annotations"] = ds["annotations"]
    df["partner"] = ds["partner"]
    df["camera"] = ds["camera"]
    df["timestamp"] = ds["date"]
    return df
def load_raw_data():
    """Main function for downloading, splitting and formatting data.

    Downloads the Hugging Face dataset, carves a test split out of "train",
    exports every split in Ultralytics format under RAW_DATA_DIR, writes a
    per-image metadata CSV to DB_INFO_URI, then clears the HF cache.

    Returns:
        pd.DataFrame: per-image metadata (also persisted to DB_INFO_URI).
    """
    # Check if data information already exists before eventually loading model
    # (the CSV acts as a "done" marker — skip the whole pipeline if present)
    if os.path.exists(DB_INFO_URI):
        df = pd.read_csv(DB_INFO_URI, index_col=0)
        return df
    # Load data
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(RAW_DATA_DIR, exist_ok=True)
    logging.info("⚙️ Dataset loading...")
    dataset = load_dataset(REPO_ID)
    # Only "train" is re-split; "val" is taken as-is from the upstream dataset
    train_test = dataset["train"].train_test_split(test_size=SPLIT_SIZE, seed=RDM_SEED)
    ds_train = train_test["train"]
    ds_val = dataset["val"]
    ds_test = train_test["test"]
    logging.info("✅ Dataset loaded in cache folder")
    # Create directory structure
    IMAGE_DIR = os.path.join(RAW_DATA_DIR, "images")
    LABEL_DIR = os.path.join(RAW_DATA_DIR, "labels")
    for split in ["train", "val", "test"]:
        os.makedirs(os.path.join(IMAGE_DIR, split), exist_ok=True)
        os.makedirs(os.path.join(LABEL_DIR, split), exist_ok=True)
    # Save dataset splits
    save_ultralytics_format(ds_train, "train", IMAGE_DIR, LABEL_DIR)
    save_ultralytics_format(ds_val, "val", IMAGE_DIR, LABEL_DIR)
    save_ultralytics_format(ds_test, "test", IMAGE_DIR, LABEL_DIR)
    # Create global dataframe from splits
    # Separate train to save memory
    # NOTE(review): 6000-row chunking hard-codes an assumed train size of
    # ~18k-24k rows; the last chunk absorbs any remainder, but a train split
    # smaller than 18000 rows yields empty chunks — TODO confirm sizes.
    df_train_1 = create_df(ds_train[:6000], "train", RAW_DATA_DIR)
    df_train_2 = create_df(ds_train[6000:12000], "train", RAW_DATA_DIR)
    df_train_3 = create_df(ds_train[12000:18000], "train", RAW_DATA_DIR)
    df_train_4 = create_df(ds_train[18000:], "train", RAW_DATA_DIR)
    df_val = create_df(ds_val, "val", RAW_DATA_DIR)
    df_test = create_df(ds_test, "test", RAW_DATA_DIR)
    # Save as one CSV
    df = pd.concat(
        [df_train_1, df_train_2, df_train_3, df_train_4, df_val, df_test],
        axis=0,
        ignore_index=True,
    )
    # Create label column for classification: any non-empty annotation = smoke
    df["label"] = "smoke"
    df.loc[df["annotations"].isna() | (df["annotations"] == ""), "label"] = "no_smoke"
    # Reorder columns
    df = df.loc[
        :,
        [
            "name",
            "label",
            "split",
            "format",
            "mode",
            "width",
            "height",
            "camera",
            "partner",
            "timestamp",
            "annotations",
            "uri",
        ],
    ]
    # Save as CSV
    with open(DB_INFO_URI, "wb") as f:
        df.to_csv(f)
    # Clear HF default cache folder after it is done (6GB)
    # 💡 Check first if path up-to-date in "clear_hf_cache.sh"
    logging.info("🧹 Removing HF default cache folder...")
    # NOTE(review): result is captured but never inspected — a failing cache
    # script goes unnoticed; consider checking result.returncode.
    result = subprocess.run(["bash", CLR_CACHE_SCRIPT], capture_output=True, text=True)
    # logging.info(result.stdout)
    logging.info("✅ HF Cache folder removed")
    return df
def clean_df(df):
    """Keep only the columns needed downstream and strip file extensions.

    Args:
        df: Raw metadata DataFrame (as produced by ``load_raw_data``).

    Returns:
        pd.DataFrame: new frame with only "name", "label", "split", "uri",
        where "name" has its file extension removed.
    """
    # Filter columns; .copy() makes an explicit new frame so the assignment
    # below cannot raise SettingWithCopyWarning or silently hit a view
    df = df[["name", "label", "split", "uri"]].copy()
    # Strip the extension with splitext: correct for ".jpg", ".png" and
    # ".jpeg" alike, unlike x[:-4] which blindly chops four characters
    df["name"] = df["name"].apply(lambda x: os.path.splitext(x)[0])
    return df
def format_data_keras(df):
    """Reorganize images into the Keras ``image_dataset_from_directory`` layout.

    Creates ``<OUTPUT_DIR>/keras/<split>/<label>/`` folders, copies every
    image there and rewrites the "uri" column to the new location.

    Args:
        df: DataFrame with "name", "label", "split" and "uri" columns
            ("name" without extension, as produced by ``clean_df``).

    Returns:
        The DataFrame (mutated in place) with "uri" pointing into the
        keras tree.
    """
    if not os.path.exists(OUTPUT_DIR):
        logging.warning(f"{OUTPUT_DIR} doesn't exist: (re)load data first")
        return df
    # Create Keras parent folder
    keras_dir = os.path.join(OUTPUT_DIR, "keras")
    # Check if data already exists.
    # BUG FIX: test keras_dir itself, not the hard-coded "./data/keras"
    # which ignored the configured OUTPUT_DIR.
    if os.path.exists(keras_dir) and len(os.listdir(keras_dir)) > 0:
        logging.info(f"{keras_dir} already exists: data already formatted")
        return df
    os.makedirs(keras_dir, exist_ok=True)
    # Create <split>/<label> folder tree
    for split in df.split.unique():
        split_dir = os.path.join(keras_dir, split)
        os.makedirs(split_dir, exist_ok=True)
        for label in df.label.unique():
            label_dir = os.path.join(split_dir, label)
            os.makedirs(label_dir, exist_ok=True)
    # Copy images to new URI and update in dataframe
    # (assumes all images are .jpg — TODO confirm no .png survives clean_df)
    df.loc[:, "uri_dest"] = df.apply(
        lambda x: os.path.join(OUTPUT_DIR, "keras", x["split"], x["label"], x["name"])
        + ".jpg",
        axis=1,
    )
    df.apply(lambda x: shutil.copy2(x["uri"], x["uri_dest"]), axis=1)
    df.drop(columns="uri", inplace=True)
    df.rename(columns={"uri_dest": "uri"}, inplace=True)
    return df
def oversample_class(df):
    """Oversample an under-represented class.

    Work in progress (commit note: "oversampling begun"): currently this
    only tallies per-(split, label) counts for the non-validation splits
    and returns the input DataFrame unchanged.
    """
    counts = (
        df.groupby(["split", "label"])
        .size()
        .reset_index(name="count")
        .query("split != 'val'")
    )
    return df
if __name__ == "__main__":
    # Running as a script triggers the full download/split/export pipeline.
    # (The previous body called the builtin help() with no arguments, which
    # just launches the interactive help utility — almost certainly a typo.)
    load_raw_data()