import time
import numpy as np
import torch
import tqdm
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from utils.pt_util import restore_model, restore_objects, save_model, save_objects
from data_proc.triplet_loss_dataset import FBanksTripletDataset
from models.triplet_loss_model import FBankTripletLossNet
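# Shared helper for train() and test(): cosine distance is 1 - cosine
# similarity, so it lies in [0, 2]; it is 0 for identically oriented
# embeddings and 1 for orthogonal ones.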
def _get_cosine_distance(a, b):
return 1 - F.cosine_similarity(a, b)
def train(model, device, train_loader, optimizer, epoch, log_interval):
model.train()
losses = []
positive_accuracy = 0
negative_accuracy = 0
    positive_distances = []
    negative_distances = []
for batch_idx, ((ax, ay), (px, py), (nx, ny)) in enumerate(tqdm.tqdm(train_loader)):
ax, px, nx = ax.to(device), px.to(device), nx.to(device)
optimizer.zero_grad()
a_out, p_out, n_out = model(ax, px, nx)
loss = model.loss(a_out, p_out, n_out)
losses.append(loss.item())
        with torch.no_grad():
            p_distance = _get_cosine_distance(a_out, p_out)
            positive_distances.append(torch.mean(p_distance).item())
            n_distance = _get_cosine_distance(a_out, n_out)
            negative_distances.append(torch.mean(n_distance).item())
            # Running decision threshold: mean positive distance plus three
            # standard deviations, recomputed from all batches seen so far.
            positive_distance_mean = np.mean(positive_distances)
            negative_distance_mean = np.mean(negative_distances)
            positive_std = np.std(positive_distances)
            threshold = positive_distance_mean + 3 * positive_std
            positive_results = p_distance < threshold
            positive_accuracy += torch.sum(positive_results).item()
            negative_results = n_distance >= threshold
            negative_accuracy += torch.sum(negative_results).item()
loss.backward()
optimizer.step()
if batch_idx % log_interval == 0:
print('{} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
time.ctime(time.time()),
epoch, batch_idx * len(ax), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
    positive_distance_mean = np.mean(positive_distances)
    negative_distance_mean = np.mean(negative_distances)
    print('Train Set: positive_distance_mean: {}, negative_distance_mean: {}, std: {}, threshold: {}'.format(
        positive_distance_mean, negative_distance_mean, positive_std, threshold))
positive_accuracy_mean = 100. * positive_accuracy / len(train_loader.dataset)
negative_accuracy_mean = 100. * negative_accuracy / len(train_loader.dataset)
return np.mean(losses), positive_accuracy_mean, negative_accuracy_mean
def test(model, device, test_loader, log_interval=None):
model.eval()
losses = []
positive_accuracy = 0
negative_accuracy = 0
    positive_distances = []
    negative_distances = []
with torch.no_grad():
for batch_idx, ((ax, ay), (px, py), (nx, ny)) in enumerate(tqdm.tqdm(test_loader)):
ax, px, nx = ax.to(device), px.to(device), nx.to(device)
a_out, p_out, n_out = model(ax, px, nx)
test_loss_on = model.loss(a_out, p_out, n_out, reduction='mean').item()
losses.append(test_loss_on)
            p_distance = _get_cosine_distance(a_out, p_out)
            positive_distances.append(torch.mean(p_distance).item())
            n_distance = _get_cosine_distance(a_out, n_out)
            negative_distances.append(torch.mean(n_distance).item())
            positive_distance_mean = np.mean(positive_distances)
            negative_distance_mean = np.mean(negative_distances)
            positive_std = np.std(positive_distances)
            # Experiment with this threshold to trade off positive vs. negative accuracy.
            threshold = positive_distance_mean + 3 * positive_std
            positive_results = p_distance < threshold
            positive_accuracy += torch.sum(positive_results).item()
            negative_results = n_distance >= threshold
            negative_accuracy += torch.sum(negative_results).item()
if log_interval is not None and batch_idx % log_interval == 0:
print('{} Test: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
time.ctime(time.time()),
batch_idx * len(ax), len(test_loader.dataset),
100. * batch_idx / len(test_loader), test_loss_on))
test_loss = np.mean(losses)
positive_accuracy_mean = 100. * positive_accuracy / len(test_loader.dataset)
negative_accuracy_mean = 100. * negative_accuracy / len(test_loader.dataset)
    positive_distance_mean = np.mean(positive_distances)
    negative_distance_mean = np.mean(negative_distances)
    print('Test Set: positive_distance_mean: {}, negative_distance_mean: {}, std: {}, threshold: {}'.format(
        positive_distance_mean, negative_distance_mean, positive_std, threshold))
print(
'\nTest set: Average loss: {:.4f}, Positive Accuracy: {}/{} ({:.0f}%), Negative Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, positive_accuracy, len(test_loader.dataset), positive_accuracy_mean, negative_accuracy,
len(test_loader.dataset), negative_accuracy_mean))
return test_loss, positive_accuracy_mean, negative_accuracy_mean
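# --- Usage sketch (not called during training) ------------------------------
# A minimal illustration of how the trained network could be used for speaker
# verification afterwards. It assumes an embedding can be obtained by feeding
# the same batch through all three branches of the triplet forward pass and
# keeping the anchor output; adapt this if FBankTripletLossNet exposes a
# dedicated single-input inference method. The helper name and threshold
# argument are hypothetical.
def verify_pair_sketch(model, device, x1, x2, threshold):
    model.eval()
    with torch.no_grad():
        x1, x2 = x1.to(device), x2.to(device)
        emb1, _, _ = model(x1, x1, x1)  # reuse the anchor branch as an encoder
        emb2, _, _ = model(x2, x2, x2)
        distance = _get_cosine_distance(emb1, emb2)
    # Same decision rule as in train()/test(): below the threshold counts as
    # the same speaker.
    return distance < threshold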
def main():
model_path = 'siamese_fbanks_saved/'
    # Fall back to CPU automatically when CUDA is not available.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
print('using device', device)
import multiprocessing
print('num cpus:', multiprocessing.cpu_count())
kwargs = {'num_workers': multiprocessing.cpu_count(),
'pin_memory': True} if use_cuda else {}
train_dataset = FBanksTripletDataset('fbanks_train')
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, **kwargs)
test_dataset = FBanksTripletDataset('fbanks_test')
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True, **kwargs)
model = FBankTripletLossNet(margin=0.2).to(device)
model = restore_model(model, model_path)
last_epoch, max_accuracy, train_losses, test_losses, train_positive_accuracies, train_negative_accuracies, \
test_positive_accuracies, test_negative_accuracies = restore_objects(model_path, (0, 0, [], [], [], [], [], []))
start = last_epoch + 1 if max_accuracy > 0 else 0
optimizer = optim.Adam(model.parameters(), lr=0.0005)
for epoch in range(start, start + 20):
train_loss, train_positive_accuracy, train_negative_accuracy = train(model, device, train_loader, optimizer,
epoch, 500)
test_loss, test_positive_accuracy, test_negative_accuracy = test(model, device, test_loader)
        print('After epoch: {}, train loss is: {}, test loss is: {}, '
              'train positive accuracy: {}, train negative accuracy: {}, '
              'test positive accuracy: {}, and test negative accuracy: {}'
              .format(epoch, train_loss, test_loss, train_positive_accuracy, train_negative_accuracy,
                      test_positive_accuracy, test_negative_accuracy))
train_losses.append(train_loss)
test_losses.append(test_loss)
train_positive_accuracies.append(train_positive_accuracy)
test_positive_accuracies.append(test_positive_accuracy)
train_negative_accuracies.append(train_negative_accuracy)
test_negative_accuracies.append(test_negative_accuracy)
test_accuracy = (test_positive_accuracy + test_negative_accuracy) / 2
if test_accuracy > max_accuracy:
max_accuracy = test_accuracy
save_model(model, epoch, model_path)
save_objects((epoch, max_accuracy, train_losses, test_losses, train_positive_accuracies,
train_negative_accuracies, test_positive_accuracies, test_negative_accuracies),
epoch, model_path)
print('saved epoch: {} as checkpoint'.format(epoch))
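# Note: the dataset directories ('fbanks_train', 'fbanks_test') and the
# checkpoint directory ('siamese_fbanks_saved/') are relative paths, so the
# script is presumably meant to be launched from the project root that
# contains them, e.g. `python trainer/triplet_loss_train.py`.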
if __name__ == '__main__':
main()