import multiprocessing
import time

import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from torch import optim
from torch.utils.data import DataLoader

from data_proc.triplet_loss_dataset import FBanksTripletDataset
from models.triplet_loss_model import FBankTripletLossNet
from utils.pt_util import restore_model, restore_objects, save_model, save_objects


def _get_cosine_distance(a, b):
    """Cosine distance between two batches of embeddings: 1 - cosine similarity."""
    return 1 - F.cosine_similarity(a, b)
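
# Quick sanity check of the helper (a sketch, not executed during training):
# identical embeddings give distance ~0, orthogonal ones 1, opposite ones ~2,
# so a smaller distance means two utterances are more alike.
#
#   a = torch.tensor([[1.0, 0.0]])
#   _get_cosine_distance(a, a)                           # tensor([~0.])
#   _get_cosine_distance(a, torch.tensor([[0.0, 1.0]]))  # tensor([1.])
#   _get_cosine_distance(a, torch.tensor([[-1.0, 0.0]])) # tensor([~2.])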


def train(model, device, train_loader, optimizer, epoch, log_interval):
    model.train()
    losses = []
    positive_accuracy = 0
    negative_accuracy = 0

    positive_distances = []
    negative_distances = []

    for batch_idx, ((ax, ay), (px, py), (nx, ny)) in enumerate(tqdm.tqdm(train_loader)):
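        # Each batch is a triplet (anchor, positive, negative), where each
        # element is an (input, label) pair. Per the triplet-loss setup, the
        # positive shares the anchor's speaker and the negative does not;
        # the labels ay/py/ny are unused here.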
        ax, px, nx = ax.to(device), px.to(device), nx.to(device)
        optimizer.zero_grad()
        a_out, p_out, n_out = model(ax, px, nx)
        loss = model.loss(a_out, p_out, n_out)
        losses.append(loss.item())

        with torch.no_grad():
            p_distance = _get_cosine_distance(a_out, p_out)
            positive_distances.append(torch.mean(p_distance).item())

            n_distance = _get_cosine_distance(a_out, n_out)
            negative_distances.append(torch.mean(n_distance).item())

            positive_distance_mean = np.mean(positive_distances)
            negative_distance_mean = np.mean(negative_distances)

            positive_std = np.std(positive_distances)
            threshold = positive_distance_mean + 3 * positive_std
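            # Note: this decision threshold is a running heuristic, not part of
            # the loss. It treats anchor-positive distances as roughly Gaussian
            # and calls a pair "same speaker" when its distance falls below
            # mean + 3 * std of the positive distances seen so far this epoch.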
            positive_results = p_distance < threshold
            positive_accuracy += torch.sum(positive_results).item()

            negative_results = n_distance >= threshold
            negative_accuracy += torch.sum(negative_results).item()

        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print('{} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                time.ctime(time.time()),
                epoch, batch_idx * len(ax), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    positive_distance_mean = np.mean(positive_distances)
    negative_distance_mean = np.mean(negative_distances)
    # positive_std and threshold carry over from the last batch's running statistics.
    print('Train Set: positive_distance_mean: {}, negative_distance_mean: {}, std: {}, threshold: {}'.format(
        positive_distance_mean, negative_distance_mean, positive_std, threshold))

    positive_accuracy_mean = 100. * positive_accuracy / len(train_loader.dataset)
    negative_accuracy_mean = 100. * negative_accuracy / len(train_loader.dataset)
    return np.mean(losses), positive_accuracy_mean, negative_accuracy_mean


def test(model, device, test_loader, log_interval=None):
    model.eval()
    losses = []
    positive_accuracy = 0
    negative_accuracy = 0

    positive_distances = []
    negative_distances = []

    with torch.no_grad():
        for batch_idx, ((ax, ay), (px, py), (nx, ny)) in enumerate(tqdm.tqdm(test_loader)):
            ax, px, nx = ax.to(device), px.to(device), nx.to(device)
            a_out, p_out, n_out = model(ax, px, nx)
            test_loss_on = model.loss(a_out, p_out, n_out, reduction='mean').item()
            losses.append(test_loss_on)

            p_distance = _get_cosine_distance(a_out, p_out)
            positive_distances.append(torch.mean(p_distance).item())

            n_distance = _get_cosine_distance(a_out, n_out)
            negative_distances.append(torch.mean(n_distance).item())

            positive_distance_mean = np.mean(positive_distances)
            negative_distance_mean = np.mean(negative_distances)

            positive_std = np.std(positive_distances)
            threshold = positive_distance_mean + 3 * positive_std

            # Same mean + 3 * std threshold heuristic as in train().
            positive_results = p_distance < threshold
            positive_accuracy += torch.sum(positive_results).item()

            negative_results = n_distance >= threshold
            negative_accuracy += torch.sum(negative_results).item()

            if log_interval is not None and batch_idx % log_interval == 0:
                print('{} Test: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    time.ctime(time.time()),
                    batch_idx * len(ax), len(test_loader.dataset),
                    100. * batch_idx / len(test_loader), test_loss_on))

    test_loss = np.mean(losses)
    positive_accuracy_mean = 100. * positive_accuracy / len(test_loader.dataset)
    negative_accuracy_mean = 100. * negative_accuracy / len(test_loader.dataset)

    positive_distance_mean = np.mean(positive_distances)
    negative_distance_mean = np.mean(negative_distances)
    print('Test Set: positive_distance_mean: {}, negative_distance_mean: {}, std: {}, threshold: {}'.format(
        positive_distance_mean, negative_distance_mean, positive_std, threshold))

    print(
        '\nTest set: Average loss: {:.4f}, Positive Accuracy: {}/{} ({:.0f}%), Negative Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, positive_accuracy, len(test_loader.dataset), positive_accuracy_mean, negative_accuracy,
            len(test_loader.dataset), negative_accuracy_mean))
    return test_loss, positive_accuracy_mean, negative_accuracy_mean


def main():
    model_path = 'siamese_fbanks_saved/'
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print('using device', device)

    print('num cpus:', multiprocessing.cpu_count())

    # pin_memory speeds up host-to-GPU copies and is only useful with CUDA.
    kwargs = {'num_workers': multiprocessing.cpu_count(),
              'pin_memory': True} if use_cuda else {}

    train_dataset = FBanksTripletDataset('fbanks_train')
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, **kwargs)

    test_dataset = FBanksTripletDataset('fbanks_test')
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True, **kwargs)

    model = FBankTripletLossNet(margin=0.2).to(device)
    model = restore_model(model, model_path)
    last_epoch, max_accuracy, train_losses, test_losses, train_positive_accuracies, train_negative_accuracies, \
        test_positive_accuracies, test_negative_accuracies = restore_objects(model_path, (0, 0, [], [], [], [], [], []))

    # Resume from the epoch after the last checkpoint; otherwise start fresh.
    start = last_epoch + 1 if max_accuracy > 0 else 0

    optimizer = optim.Adam(model.parameters(), lr=0.0005)
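    # Note (inferred from the helper signatures, not verified): restore_model
    # reloads only the model weights, so the Adam optimizer state starts fresh
    # on every resume rather than being restored from the checkpoint.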

    for epoch in range(start, start + 20):
        train_loss, train_positive_accuracy, train_negative_accuracy = train(model, device, train_loader, optimizer,
                                                                             epoch, 500)
        test_loss, test_positive_accuracy, test_negative_accuracy = test(model, device, test_loader)
        print('After epoch: {}, train loss is: {}, test loss is: {}, '
              'train positive accuracy: {}, train negative accuracy: {}, '
              'test positive accuracy: {}, and test negative accuracy: {}'
              .format(epoch, train_loss, test_loss, train_positive_accuracy, train_negative_accuracy,
                      test_positive_accuracy, test_negative_accuracy))

        train_losses.append(train_loss)
        test_losses.append(test_loss)
        train_positive_accuracies.append(train_positive_accuracy)
        test_positive_accuracies.append(test_positive_accuracy)

        train_negative_accuracies.append(train_negative_accuracy)
        test_negative_accuracies.append(test_negative_accuracy)

        # Checkpoint only when the averaged test accuracy improves.
        test_accuracy = (test_positive_accuracy + test_negative_accuracy) / 2

        if test_accuracy > max_accuracy:
            max_accuracy = test_accuracy
            save_model(model, epoch, model_path)
            save_objects((epoch, max_accuracy, train_losses, test_losses, train_positive_accuracies,
                          train_negative_accuracies, test_positive_accuracies, test_negative_accuracies),
                         epoch, model_path)
            print('saved epoch: {} as checkpoint'.format(epoch))


if __name__ == '__main__':
    main()
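
# Example invocation (a sketch; assumes the 'fbanks_train' and 'fbanks_test'
# directories have already been produced by the data_proc pipeline, and that
# this file is saved as, say, train_triplet_loss.py):
#
#   python train_triplet_loss.py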