|
|
|
|
|
|
|
import argparse |
|
import json |
|
import os |
|
import random |
|
import re |
|
import sys |
|
import time |
|
import traceback |
|
from datetime import datetime, timedelta |
|
|
|
import pkg_resources |
|
|
|
from .gpu import collect_gpu_usage |
|
|
|
# Exit the runner after this many seconds without any running trial.
idle_timeout_seconds = 10 * 60

# Minimum number of seconds between two GPU usage reports.
# NOTE: the misspelling ("refressh") is kept on purpose; main_loop reads this
# exact name.
gpu_refressh_interval_seconds = 5

# Extracts "<major>[.<minor>]" from a version string such as "v2.10.1".
# A raw string avoids the invalid "\." escape, and "+" lets each component
# span several digits so that "2.10" is not truncated to "2.1".
regular = re.compile(r'v?(?P<version>[0-9]+(\.[0-9]+)?).*')

# Set to the RemoteLogger instance once the script entry point has created it.
trial_runner_syslogger = None
|
|
|
|
|
def main_loop(args):
    '''Run the trial runner command loop until it has been idle for too long.

    Each iteration receives one command from ``args.command_channel`` and
    starts, kills, or sends parameters to the addressed trial. Trials that
    are no longer running are then dropped, GPU usage is reported when
    ``args.enable_gpu_collect`` is set, and the loop sleeps for 0.5 seconds.
    The loop ends once no trial has been running for more than
    ``idle_timeout_seconds``.

    On exit (normal or not) every remaining trial is killed, up to ten
    seconds are spent waiting for the channel to flush, and the channel is
    closed.

    Parameters
    ----------
    args : argparse.Namespace
        Merged runner arguments. This function reads ``job_pid_file``,
        ``command_channel``, ``enable_gpu_collect`` and ``node_id``; the whole
        object is also handed to every ``Trial`` it creates.

    Raises
    ------
    Exception
        If a new trial is requested for an id that is still running, if an
        unknown command type is received, or if anything else in the loop
        fails. The traceback is printed before the exception is re-raised.
    '''
    idle_last_time = datetime.now()
    # Back-date the last refresh so the first iteration reports GPU usage.
    gpu_refresh_last_time = datetime.now() - timedelta(minutes=1)

    # Bind these before the ``try`` so the ``finally`` clean-up never trips
    # over an unbound local (e.g. when the pid file cannot be written) and
    # hides the original error.
    trials = {}
    command_channel = None

    try:
        if args.job_pid_file:
            with open(args.job_pid_file, 'w') as job_file:
                job_file.write("%d" % os.getpid())

        command_channel = args.command_channel

        while True:
            command_type, command_data = command_channel.receive()
            if command_type == CommandType.NewTrialJob:
                trial_id = command_data["trialId"]
                if trial_id in trials:
                    trial = trials[trial_id]
                    if trial.is_running():
                        raise Exception('trial %s is running already, cannot start a new one' % trial.id)
                    # A finished trial with the same id is replaced below.
                    del trials[trial_id]
                trial = Trial(args, command_data)
                trial.run()
                trials[trial_id] = trial
            elif command_type == CommandType.KillTrialJob:
                # For a kill command the payload itself is the trial id.
                trial_id = command_data
                if trial_id in trials:
                    trials[trial_id].kill(command_data)
            elif command_type == CommandType.SendTrialJobParameter:
                trial_id = command_data["trialId"]
                if trial_id in trials:
                    trials[trial_id].save_parameter_file(command_data)
            elif command_type is not None:
                raise Exception("unknown command %s" % command_type)

            # Iterate over a snapshot and delete by dictionary key, so a
            # ``None`` entry is removed instead of crashing on ``trial.id``.
            for trial_id, trial in list(trials.items()):
                if trial is not None and trial.is_running():
                    idle_last_time = datetime.now()
                else:
                    del trials[trial_id]

            # total_seconds() is the full signed duration; ``.seconds`` is only
            # one component and turns a small negative delta (clock stepped
            # backwards) into ~86399, which would trigger a false idle exit.
            if (datetime.now() - idle_last_time).total_seconds() > idle_timeout_seconds:
                nni_log(LogType.Info, "trial runner is idle more than {0} seconds, so exit.".format(
                    idle_timeout_seconds))
                break

            if args.enable_gpu_collect and \
                    (datetime.now() - gpu_refresh_last_time).total_seconds() > gpu_refressh_interval_seconds:
                gpu_info = collect_gpu_usage(args.node_id)
                command_channel.send(CommandType.ReportGpuInfo, gpu_info)
                gpu_refresh_last_time = datetime.now()
            time.sleep(0.5)
    except Exception:
        traceback.print_exc()
        # Bare ``raise`` keeps the original traceback intact.
        raise
    finally:
        nni_log(LogType.Info, "main_loop exits.")

        for trial_id, trial in list(trials.items()):
            trial.kill()
            del trials[trial_id]

        if command_channel is not None:
            # Give pending outgoing commands up to ten seconds to be sent.
            for _ in range(10):
                if command_channel.sent():
                    break
                time.sleep(1)
            command_channel.close()
|
|
|
|
|
def trial_runner_help_info(*args):
    '''Print a hint telling the user to run the script with ``--help``.

    Any positional arguments are accepted and ignored.
    '''
    guidance = 'please run --help to see guidance'
    print(guidance)
|
|
|
|
|
def check_version(args):
    '''Compare the installed ``nni`` package version with the NNI manager's.

    The version of the installed ``nni`` distribution and
    ``args.nni_manager_version`` are both reduced with the module-level
    ``regular`` pattern and compared. The outcome is reported through
    ``args.command_channel`` as a ``CommandType.VersionCheck`` command tagged
    ``VCSuccess`` or ``VCFail``.

    The process is terminated with exit code 1 when the ``nni`` package
    cannot be resolved or when the versions differ. The check is skipped
    (with a warning) when ``args.nni_manager_version`` is empty. A version
    string the pattern cannot parse is logged as an error and otherwise
    ignored.

    Parameters
    ----------
    args : argparse.Namespace
        Reads ``nni_manager_version``, ``command_channel`` and ``node_id``.
    '''
    try:
        trial_runner_version = pkg_resources.get_distribution('nni').version
    except pkg_resources.ResolutionError:
        nni_log(LogType.Error, 'Package nni does not exist!')
        os._exit(1)

    if not args.nni_manager_version:
        nni_log(LogType.Warning, 'Skipping version check!')
        return

    try:
        command_channel = args.command_channel
        # ``search`` returns None for an unparsable version; the resulting
        # AttributeError on ``.group`` is handled below.
        trial_runner_version = regular.search(trial_runner_version).group('version')
        nni_log(LogType.Info, '{0}: runner_version is {1}'.format(args.node_id, trial_runner_version))
        nni_manager_version = regular.search(args.nni_manager_version).group('version')
        nni_log(LogType.Info, '{0}: nni_manager_version is {1}'.format(args.node_id, nni_manager_version))
        log_entry = {}
        if trial_runner_version != nni_manager_version:
            nni_log(LogType.Error, '{0}: Version does not match!'.format(args.node_id))
            error_message = '{0}: NNIManager version is {1}, Trial runner version is {2}, NNI version does not match!'.format(
                args.node_id, nni_manager_version, trial_runner_version)
            log_entry['tag'] = 'VCFail'
            log_entry['msg'] = error_message
            command_channel.send(CommandType.VersionCheck, log_entry)
            # Wait at most ten seconds for the failure report to leave, the
            # same bound the other flush loops in this file use, so a channel
            # that never drains cannot hang the runner forever.
            for _ in range(10):
                if command_channel.sent():
                    break
                time.sleep(1)
            os._exit(1)
        else:
            nni_log(LogType.Info, '{0}: Version match!'.format(args.node_id))
            log_entry['tag'] = 'VCSuccess'
            command_channel.send(CommandType.VersionCheck, log_entry)
    except AttributeError as err:
        nni_log(LogType.Error, '{0}: {1}'.format(args.node_id, err))
|
|
|
if __name__ == '__main__':
    # NNI Trial Runner entry point: merge command-line arguments with
    # settings.json, export the NNI_* environment variables, open the command
    # channel, redirect stdout/stderr to the remote logger, check the version
    # and run the main loop.
    PARSER = argparse.ArgumentParser()
    PARSER.set_defaults(func=trial_runner_help_info)
    PARSER.add_argument('--trial_command', type=str, help='Command to launch trial process')
    PARSER.add_argument('--nnimanager_ip', type=str, help='NNI manager rest server IP')
    PARSER.add_argument('--nnimanager_port', type=str, help='NNI manager rest server port')
    PARSER.add_argument('--nni_manager_version', type=str, help='the nni version transmitted from nniManager')
    PARSER.add_argument('--log_collection', type=str, help='set the way to collect log in trial runner')
    PARSER.add_argument('--node_count', type=int, help='number of nodes, it determines how to consume command and save code file')
    PARSER.add_argument('--job_pid_file', type=str, help='save trial runner process pid')
    args, unknown = PARSER.parse_known_args()

    # Look for settings.json in the working directory, then in its parent.
    setting_file = "settings.json"
    if not os.path.exists(setting_file):
        setting_file = "../{}".format(setting_file)
    if os.path.exists(setting_file):
        with open(setting_file, 'r') as fp:
            settings = json.load(fp)
        print("setting is {}".format(settings))
    else:
        # Everything below reads ``settings``; stop here with a clear message
        # instead of failing later with a NameError.
        print("not found setting file")
        sys.exit(1)

    args.exp_id = settings["experimentId"]
    args.platform = settings["platform"]

    # The runner id is the name of the directory the runner was started in.
    args.runner_id = os.path.basename(os.path.realpath(os.path.curdir))
    args.runner_name = "runner_"+args.runner_id
    args.enable_gpu_collect = settings["enableGpuCollector"]
    args.command_channel = settings["commandChannel"]

    # Command-line values take precedence; settings.json fills in the rest.
    if args.trial_command is None:
        args.trial_command = settings["command"]
    if args.nnimanager_ip is None:
        args.nnimanager_ip = settings["nniManagerIP"]
    if args.nnimanager_port is None:
        args.nnimanager_port = settings["nniManagerPort"]
    if args.nni_manager_version is None:
        args.nni_manager_version = settings["nniManagerVersion"]
    if args.log_collection is None:
        args.log_collection = settings["logCollection"]
    if args.node_count is None:
        # Default to a single node.
        args.node_count = 1

    os.environ['NNI_OUTPUT_DIR'] = os.curdir + "/nnioutput"
    os.environ['NNI_PLATFORM'] = args.platform
    os.environ['NNI_SYS_DIR'] = os.curdir
    os.environ['NNI_EXP_ID'] = args.exp_id
    os.environ['MULTI_PHASE'] = "true"
    os.environ['NNI_TRIAL_JOB_ID'] = "runner"

    # NOTE(review): these imports are kept after the environment set-up above,
    # as in the original; presumably the modules read NNI_* variables at
    # import time. Confirm before moving them to the top of the file.
    from .log_utils import LogType, RemoteLogger, StdOutputType, nni_log
    from .trial import Trial
    from .file_channel import FileChannel
    from .web_channel import WebChannel
    from .commands import CommandType

    is_multi_node = args.node_count > 1

    if is_multi_node:
        # Claim a unique node id by creating a marker file. Mode "x" creates
        # the file atomically and fails if it already exists, so two nodes
        # that draw the same number cannot both claim it (a separate
        # exists()-then-open("w") sequence leaves a window for that).
        while True:
            node_id = random.randint(0, 10000)
            unique_check_file_name = "node_%s" % (node_id)
            try:
                with open(unique_check_file_name, "x") as unique_check_file:
                    unique_check_file.write("%s" % (int(datetime.now().timestamp() * 1000)))
            except FileExistsError:
                continue
            break
        args.node_id = node_id
    else:
        # A single node needs no id.
        args.node_id = None

    # Build the command channel selected in settings.json.
    command_channel = None
    if args.command_channel == "file":
        command_channel = FileChannel(args)
    elif args.command_channel == 'aml':
        from .aml_channel import AMLChannel
        command_channel = AMLChannel(args)
    else:
        command_channel = WebChannel(args)
    command_channel.open()

    nni_log(LogType.Info, "command channel is {}, actual type is {}".format(args.command_channel, type(command_channel)))
    # From here on ``args.command_channel`` is the channel object, not its name.
    args.command_channel = command_channel

    trial_runner_syslogger = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'runner',
                                          StdOutputType.Stdout, args.log_collection, args.runner_name, command_channel)
    sys.stdout = sys.stderr = trial_runner_syslogger
    nni_log(LogType.Info, "{}: merged args is {}".format(args.node_id, args))

    if args.trial_command is None:
        nni_log(LogType.Error, "{}: no command is found.".format(args.node_id))
        os._exit(1)
    check_version(args)
    try:
        main_loop(args)
    except SystemExit as se:
        nni_log(LogType.Info, '{}: NNI trial runner exit with code {}'.format(args.node_id, se.code))

        # Best effort: give pending messages up to ten seconds to be sent.
        timeout = 10
        while not command_channel.sent() and timeout > 0:
            timeout -= 1
            time.sleep(1)

        # os._exit() requires an int, but SystemExit.code may be None
        # (success) or an arbitrary object such as a message (failure).
        if se.code is None:
            exit_code = 0
        elif isinstance(se.code, int):
            exit_code = se.code
        else:
            exit_code = 1
        os._exit(exit_code)
    finally:
        if trial_runner_syslogger is not None:
            if trial_runner_syslogger.pipeReader is not None:
                trial_runner_syslogger.pipeReader.set_process_exit()
            trial_runner_syslogger.close()

    # Force the process to exit once the main loop has returned.
    os._exit(0)
|
|