# watermelon2 / train_watermelon.py
# Xalphinions's picture
# Upload folder using huggingface_hub
# 5900417 verified
import os
import time
import torch
import torchaudio
import torchvision
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import sys
# Add parent directory to path to import the preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from preprocess import process_audio_data, process_image_data
# Training hyperparameters read by train_model()
batch_size = 16
epochs = 2
learning_rate = 0.0001

# Report the versions of the core libraries in use
print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")

# Choose the compute backend: CUDA first, then Apple MPS, otherwise CPU.
# MPS availability is only queried when CUDA is not available.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"\033[92mINFO\033[0m: Using device: {device}")

# Make sure the default checkpoint directory is present
os.makedirs("models/", exist_ok=True)
class WatermelonDataset(Dataset):
    """Paired audio/image watermelon samples labelled with a sweetness value.

    The dataset root is expected to be laid out as::

        <data_dir>/<sweetness>/<id>/<id>.wav
        <data_dir>/<sweetness>/<id>/<id>.jpg

    where the ``<sweetness>`` directory name parses as a float. Only ids for
    which both the ``.wav`` and the ``.jpg`` file exist are indexed.
    """

    def __init__(self, data_dir):
        """Index every complete (audio, image) pair found under ``data_dir``.

        Entries directly under ``data_dir`` that are not directories are
        ignored; directories whose name is not a valid float are skipped
        with a warning.

        Args:
            data_dir: Root directory of the dataset.

        Raises:
            OSError: If ``data_dir`` (or a sweetness directory) cannot be
                listed.
        """
        self.data_dir = data_dir
        # List of (audio_path, image_path, sweetness) tuples
        self.samples = []

        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # sample indices stable across runs and machines.
        for sweetness_dir in sorted(os.listdir(data_dir)):
            sweetness_path = os.path.join(data_dir, sweetness_dir)
            # Test for a directory *before* parsing the name so stray files
            # (e.g. ".DS_Store") cannot crash the constructor.
            if not os.path.isdir(sweetness_path):
                continue
            try:
                sweetness = float(sweetness_dir)
            except ValueError:
                print(
                    f"\033[93mWARN\033[0m: Skipping non-numeric directory: {sweetness_path}"
                )
                continue

            for id_dir in sorted(os.listdir(sweetness_path)):
                id_path = os.path.join(sweetness_path, id_dir)
                if not os.path.isdir(id_path):
                    continue
                audio_file = os.path.join(id_path, f"{id_dir}.wav")
                image_file = os.path.join(id_path, f"{id_dir}.jpg")
                # Keep the sample only when both modalities are present
                if os.path.exists(audio_file) and os.path.exists(image_file):
                    self.samples.append((audio_file, image_file, sweetness))

        print(f"\033[92mINFO\033[0m: Loaded {len(self.samples)} samples from {data_dir}")

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load and preprocess the sample at ``idx``.

        If loading or preprocessing fails for any sample other than index 0,
        the error is printed and sample 0 (including its label) is returned
        in its place.

        Args:
            idx: Index into ``self.samples``.

        Returns:
            A tuple ``(mfcc, processed_image, label)`` holding the result of
            ``process_audio_data``, the result of ``process_image_data`` and
            the sweetness as a float32 scalar tensor.

        Raises:
            IndexError: If ``idx`` is out of range.
            Exception: Whatever loading/preprocessing raised, when the
                failing sample is index 0 (no fallback is available).
        """
        audio_path, image_path, label = self.samples[idx]
        try:
            # Audio branch
            waveform, sample_rate = torchaudio.load(audio_path)
            mfcc = process_audio_data(waveform, sample_rate)
            # Image branch
            image = torchvision.io.read_image(image_path).float()
            processed_image = process_image_data(image)
            return mfcc, processed_image, torch.tensor(label, dtype=torch.float32)
        except Exception as e:
            print(f"\033[91mERR!\033[0m: Error processing sample {idx}: {e}")
            if idx == 0:
                # Sample 0 is the fallback itself; re-raise (keeping the
                # original traceback) instead of recursing forever.
                raise
            return self.__getitem__(0)
class WatermelonModel(torch.nn.Module):
    """Two-branch regressor combining audio (LSTM) and image (ResNet50) features.

    Each branch is projected to 128 dimensions; the two projections are
    concatenated (256) and passed through a small fully connected head that
    emits a single value per example.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        # Audio branch: 2-layer LSTM over 376-wide feature vectors, then a
        # linear projection of the 64-dim hidden state to 128 dims.
        self.lstm = nn.LSTM(
            input_size=376,
            hidden_size=64,
            num_layers=2,
            batch_first=True,
        )
        self.lstm_fc = nn.Linear(64, 128)

        # Image branch: ResNet50 with default weights whose classifier layer
        # is swapped for a 128-dim projection.
        default_weights = torchvision.models.ResNet50_Weights.DEFAULT
        self.resnet = torchvision.models.resnet50(weights=default_weights)
        backbone_width = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(backbone_width, 128)

        # Prediction head over the concatenated 128 + 128 features
        self.fc1 = nn.Linear(256, 64)
        self.fc2 = nn.Linear(64, 1)
        self.relu = nn.ReLU()

    def forward(self, mfcc, image):
        """Predict one value per example from an audio/image pair.

        Args:
            mfcc: Batch-first sequence tensor for the LSTM, i.e.
                ``(batch, steps, 376)``.
            image: Image batch fed to the ResNet50 backbone
                (presumably ``(batch, 3, H, W)`` — confirm against
                ``process_image_data``).

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        # Audio: keep only the last time step of the LSTM output sequence
        sequence_out, _ = self.lstm(mfcc)
        audio_features = self.lstm_fc(sequence_out[:, -1, :])

        # Image: 128-dim embedding from the modified ResNet50
        image_features = self.resnet(image)

        # Fuse both modalities along the feature axis and regress
        fused = torch.cat((audio_features, image_features), dim=1)
        hidden = self.relu(self.fc1(fused))
        return self.fc2(hidden)
def train_model(data_dir, output_dir="models/"):
    """Train a ``WatermelonModel`` on the dataset found under ``data_dir``.

    The dataset is split 70/20/10 into train/validation/test subsets (the
    test subset is only reported, not evaluated here). Training uses MSE
    loss and Adam with the module-level ``batch_size``, ``epochs``,
    ``learning_rate`` and ``device`` settings. Per-batch training loss and
    per-epoch validation loss are written to TensorBoard under ``runs/``.
    A checkpoint named ``model_<epoch>_<timestamp>.pt`` is saved after every
    epoch and ``watermelon_model_final.pt`` once training ends.

    Batches that raise during training or validation are reported and
    skipped; epoch averages are computed over the batches that succeeded.

    Args:
        data_dir: Root directory passed to ``WatermelonDataset``.
        output_dir: Directory receiving the checkpoints and the final model.
            Created if it does not exist.

    Returns:
        Path of the saved final model state dict.
    """
    # Callers other than the CLI may pass a directory that does not exist yet
    os.makedirs(output_dir, exist_ok=True)

    # Create dataset
    dataset = WatermelonDataset(data_dir)
    n_samples = len(dataset)

    # Split dataset; the test share absorbs the rounding remainder
    train_size = int(0.7 * n_samples)
    val_size = int(0.2 * n_samples)
    test_size = n_samples - train_size - val_size
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size]
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model, loss function and optimizer
    model = WatermelonModel().to(device)
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    print(f"\033[92mINFO\033[0m: Training model for {epochs} epochs")
    print(f"\033[92mINFO\033[0m: Training samples: {len(train_dataset)}")
    print(f"\033[92mINFO\033[0m: Validation samples: {len(val_dataset)}")
    print(f"\033[92mINFO\033[0m: Test samples: {len(test_dataset)}")
    print(f"\033[92mINFO\033[0m: Batch size: {batch_size}")

    # TensorBoard
    writer = SummaryWriter("runs/")
    global_step = 0
    try:
        for epoch in range(epochs):
            print(f"\033[92mINFO\033[0m: Training epoch ({epoch+1}/{epochs})")

            # Training phase
            model.train()
            running_loss = 0.0
            train_batches = 0  # batches that actually contributed a loss
            for i, (mfcc, image, label) in enumerate(train_loader):
                try:
                    mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
                    optimizer.zero_grad()
                    output = model(mfcc, image)
                    # Match the (batch, 1) shape of the model output
                    label = label.view(-1, 1).float()
                    loss = criterion(output, label)
                    loss.backward()
                    optimizer.step()

                    batch_loss = loss.item()
                    running_loss += batch_loss
                    train_batches += 1
                    writer.add_scalar("Training Loss", batch_loss, global_step)
                    global_step += 1
                    if i % 10 == 0:
                        print(
                            f"\033[92mINFO\033[0m: Batch {i}/{len(train_loader)}, "
                            f"Loss: {batch_loss:.4f}"
                        )
                except Exception as e:
                    # Best effort: report the failing batch and keep training
                    print(f"\033[91mERR!\033[0m: Error in training batch {i}: {e}")
                    continue

            # Validation phase
            model.eval()
            val_loss = 0.0
            val_batches = 0
            with torch.no_grad():
                for i, (mfcc, image, label) in enumerate(val_loader):
                    try:
                        mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
                        output = model(mfcc, image)
                        label = label.view(-1, 1).float()
                        val_loss += criterion(output, label).item()
                        val_batches += 1
                    except Exception as e:
                        print(f"\033[91mERR!\033[0m: Error in validation batch {i}: {e}")
                        continue

            # Average over successful batches only, so skipped batches do not
            # drag the reported loss towards zero; inf when nothing ran.
            avg_train_loss = running_loss / train_batches if train_batches else float("inf")
            avg_val_loss = val_loss / val_batches if val_batches else float("inf")

            # Record validation loss
            writer.add_scalar("Validation Loss", avg_val_loss, epoch)
            print(
                f"Epoch [{epoch+1}/{epochs}], Training Loss: {avg_train_loss:.4f}, "
                f"Validation Loss: {avg_val_loss:.4f}"
            )

            # Save model checkpoint
            timestamp = time.strftime("%Y%m%d-%H%M%S")
            model_path = os.path.join(output_dir, f"model_{epoch+1}_{timestamp}.pt")
            torch.save(model.state_dict(), model_path)
            print(
                f"\033[92mINFO\033[0m: Model checkpoint epoch [{epoch+1}/{epochs}] saved: {model_path}"
            )
    finally:
        # Flush pending events and release the event file even on failure
        writer.close()

    # Save final model
    final_model_path = os.path.join(output_dir, "watermelon_model_final.pt")
    torch.save(model.state_dict(), final_model_path)
    print(f"\033[92mINFO\033[0m: Final model saved: {final_model_path}")
    print("\033[92mINFO\033[0m: Training complete")
    return final_model_path
if __name__ == "__main__":
    import argparse

    # Command-line interface: (flag, default, help text) for each string option
    cli_options = (
        ("--data_dir", "../cleaned", "Path to the cleaned dataset directory"),
        (
            "--output_dir",
            "models/",
            "Directory to save model checkpoints and the final model",
        ),
    )
    arg_parser = argparse.ArgumentParser(
        description="Train the Watermelon Sweetness Prediction Model"
    )
    for flag, default_value, help_text in cli_options:
        arg_parser.add_argument(flag, type=str, default=default_value, help=help_text)
    options = arg_parser.parse_args()

    # The checkpoint directory must exist before training starts writing to it
    os.makedirs(options.output_dir, exist_ok=True)

    # Run training and report where the final weights ended up
    saved_model_path = train_model(options.data_dir, options.output_dir)
    print(f"\033[92mINFO\033[0m: Training completed successfully!")
    print(f"\033[92mINFO\033[0m: Final model saved at: {saved_model_path}")