import os
import re
from importlib import import_module

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.nn.utils import parameters_to_vector, vector_to_parameters

from .discriminator import Discriminator

from utils import interact  # interactive debugging helper; not used directly in this module


class Model(nn.Module):
    """Wrapper that builds the generator (and an optional discriminator),
    moves them to the target device, and handles checkpointing."""

    def __init__(self, args):
        super(Model, self).__init__()

        self.args = args
        self.device = args.device
        self.n_GPUs = args.n_GPUs
        self.save_dir = os.path.join(args.save_dir, 'models')
        os.makedirs(self.save_dir, exist_ok=True)

        # dynamically import the generator architecture from model/<args.model>.py
        module = import_module('model.' + args.model)

        # attribute assignment on nn.ModuleDict registers the submodule, so
        # self.model.G is also reachable as self.model['G'] and via iteration;
        # setting D to None leaves it as a plain attribute outside the dict
        self.model = nn.ModuleDict()
        self.model.G = module.build_model(args)
        if self.args.loss.lower().find('adv') >= 0:
            # an adversarial loss term requires a discriminator
            self.model.D = Discriminator(self.args)
        else:
            self.model.D = None

        self.to(args.device, dtype=args.dtype, non_blocking=True)
        self.load(args.load_epoch, path=args.pretrained)

    def parallelize(self):
        """Wrap each submodule with (Distributed)DataParallel when running on CUDA."""
        if self.args.device_type == 'cuda':
            if self.args.distributed:
                # one process per GPU: wrap with DDP on this process's device
                Parallel = DistributedDataParallel
                parallel_args = {
                    'device_ids': [self.args.rank],
                    'output_device': self.args.rank,
                }
            else:
                # a single process drives all visible GPUs
                Parallel = DataParallel
                parallel_args = {
                    'device_ids': list(range(self.n_GPUs)),
                    'output_device': self.args.rank,
                }

            for model_key in self.model:
                if self.model[model_key] is not None:
                    self.model[model_key] = Parallel(self.model[model_key], **parallel_args)
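
    # Note: DistributedDataParallel requires the default process group to be
    # initialized before the model is wrapped. A minimal sketch, assuming the
    # usual environment-variable launcher and one process per GPU (this setup
    # belongs in the training script, not in this class):
    #
    #     dist.init_process_group(backend='nccl', init_method='env://')
    #     torch.cuda.set_device(args.rank)
    #     model.parallelize()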

    def forward(self, input):
        return self.model.G(input)

    def _save_path(self, epoch):
        model_path = os.path.join(self.save_dir, 'model-{:d}.pt'.format(epoch))
        return model_path

    def state_dict(self):
        """Collect per-submodule state dicts, unwrapping any (Distributed)DataParallel."""
        state_dict = {}
        for model_key in self.model:
            if self.model[model_key] is not None:
                parallelized = isinstance(self.model[model_key], (DataParallel, DistributedDataParallel))
                if parallelized:
                    # save the wrapped module so checkpoints are independent of the wrapper
                    state_dict[model_key] = self.model[model_key].module.state_dict()
                else:
                    state_dict[model_key] = self.model[model_key].state_dict()

        return state_dict

    def load_state_dict(self, state_dict, strict=True):
        """Restore per-submodule state dicts, unwrapping any (Distributed)DataParallel."""
        for model_key in self.model:
            parallelized = isinstance(self.model[model_key], (DataParallel, DistributedDataParallel))
            if model_key in state_dict:
                if parallelized:
                    self.model[model_key].module.load_state_dict(state_dict[model_key], strict=strict)
                else:
                    self.model[model_key].load_state_dict(state_dict[model_key], strict=strict)

    def save(self, epoch):
        torch.save(self.state_dict(), self._save_path(epoch))

    def load(self, epoch=None, path=None):
        if path:
            # an explicit checkpoint path (e.g. a pretrained model) takes priority
            model_name = path
        elif isinstance(epoch, int):
            if epoch < 0:
                # a negative epoch means "resume from the latest checkpoint"
                epoch = self.get_last_epoch()
            if epoch == 0:
                # no checkpoint has been saved yet; keep the freshly initialized weights
                return

            model_name = self._save_path(epoch)
        else:
            raise Exception('no epoch number or model path specified!')

        print('Loading model from {}'.format(model_name))
        state_dict = torch.load(model_name, map_location=self.args.device)
        self.load_state_dict(state_dict)

        return

    def synchronize(self):
        """Broadcast rank 0's parameters so all distributed processes start identically."""
        if self.args.distributed:
            # flatten every parameter into one contiguous vector for a single broadcast
            vector = parameters_to_vector(self.parameters())

            dist.broadcast(vector, 0)
            if self.args.rank != 0:
                # non-zero ranks overwrite their parameters with rank 0's copy
                vector_to_parameters(vector, self.parameters())

            del vector

        return
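
    # For reference, an equivalent per-parameter sketch (same assumption of an
    # already-initialized process group; not used by this class):
    #
    #     for p in self.parameters():
    #         dist.broadcast(p.data, src=0)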

    def get_last_epoch(self):
        """Return the largest epoch number among saved checkpoints, or 0 if none exist."""
        model_list = [name for name in os.listdir(self.save_dir) if re.findall(r'\d+', name)]
        if len(model_list) == 0:
            epoch = 0
        else:
            # take the numeric maximum; a plain lexicographic sort would rank
            # 'model-9.pt' after 'model-10.pt'
            epoch = max(int(re.findall(r'\d+', name)[0]) for name in model_list)

        return epoch

    def print(self):
        print(self.model)

        return
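

# Typical usage (a sketch only; the argument names below are the attributes this
# class reads from `args`, normally supplied by the training script's argparse):
#
#     model = Model(args)      # builds G, plus D when args.loss contains 'adv'
#     model.parallelize()      # wraps submodules with (Distributed)DataParallel on CUDA
#     output = model(input)    # forward pass through the generator
#     model.save(epoch)        # writes <args.save_dir>/models/model-<epoch>.pt
#     model.load(epoch=-1)     # resumes from the most recent checkpoint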