Spaces:
Running
Running
| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT license. | |
| import logging | |
| import os | |
| import threading | |
| from enum import Enum | |
# Module-level logger named after this module, so its output can be
# configured through the standard logging hierarchy.
_logger = logging.getLogger(name=__name__)
class CommandType(Enum):
    """Two-byte codes identifying each IPC message exchanged with the Training Service.

    The code is the first field of every framed message; ``receive()`` turns
    the first two header bytes back into a member with ``CommandType(code)``.
    """

    # ---- inbound ("in") ----
    Initialize = b"IN"
    RequestTrialJobs = b"GE"
    ReportMetricData = b"ME"
    UpdateSearchSpace = b"SS"
    ImportData = b"FD"
    AddCustomizedTrialJob = b"AD"
    TrialEnd = b"EN"
    Terminate = b"TE"
    Ping = b"PI"

    # ---- outbound ("out") ----
    Initialized = b"ID"
    NewTrialJob = b"TR"
    SendTrialJobParameter = b"SP"
    NoMoreTrialJobs = b"NO"
    KillTrialJob = b"KI"
# Shared by send(): each framed message is written while holding this lock,
# so concurrent senders cannot interleave bytes on the output pipe.
_lock = threading.Lock()

# NOTE(review): fds 3 and 4 are presumably pipes inherited from the process
# that launched this one (3 = read side, 4 = write side) — confirm against
# the launcher. When NNI_PLATFORM is 'unittest' nothing is opened, so
# _in_file / _out_file are left unbound here; presumably the tests assign
# them directly — verify.
if os.environ.get('NNI_PLATFORM') != 'unittest':
    try:
        _in_file = open(3, 'rb')
        _out_file = open(4, 'wb')
    except OSError:
        # Only logged, not fatal. If fd 3 opened but fd 4 failed, _in_file
        # stays bound while _out_file does not.
        _logger.debug('IPC pipeline not exists')
def send(command, data):
    """Send command to Training Service.

    The message is framed as: the 2-byte command code, the payload length in
    bytes as a 14-digit zero-padded decimal, then the UTF-8 encoded payload.

    Args:
        command: CommandType object; its ``value`` is the 2-byte code.
        data: string payload.
    """
    # Encoding and framing touch no shared state, so they are done before
    # taking the lock; the critical section covers only the pipe write.
    payload = data.encode('utf8')
    msg = b'%b%014d%b' % (command.value, len(payload), payload)
    # ``with`` releases the lock only if it was actually acquired (the old
    # acquire-inside-try could call release() on an unheld lock).
    with _lock:
        _logger.debug('Sending command, data: [%s]', msg)
        _out_file.write(msg)
        _out_file.flush()
def _read_exactly(stream, size):
    """Read up to ``size`` bytes from ``stream``, retrying on short reads.

    Returns the bytes read. The result is shorter than ``size`` only when a
    read returned ``b''`` (EOF) or ``None`` before ``size`` bytes arrived.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = stream.read(remaining)
        if not chunk:
            # b'' is EOF; None (non-blocking stream, no data) is treated the
            # same way, as receive() has always done for the header.
            break
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)


def receive():
    """Receive a command from Training Service.

    Reads one framed message: a 16-byte header (2-byte command code followed
    by the payload length as 14 ASCII decimal digits), then that many bytes
    of UTF-8 payload.

    Returns:
        A tuple ``(command, payload)`` of a CommandType member and a str, or
        ``(None, None)`` if the pipe reaches EOF before a complete message
        (header and payload) has been read.

    Raises:
        ValueError: if the length field is not all ASCII digits, the command
            code is not a CommandType value, or the payload is not valid UTF-8.
    """
    header = _read_exactly(_in_file, 16)
    _logger.debug('Received command, header: [%s]', header)
    if len(header) < 16:
        # Pipe EOF encountered
        _logger.debug('Pipe EOF encountered')
        return None, None
    length_field = header[2:]
    # int() would also accept signs, spaces and underscores; a negative
    # length would then make read() consume the rest of the stream.
    if not length_field.isdigit():
        raise ValueError('Malformed command header: %r' % (header,))
    length = int(length_field)
    data = _read_exactly(_in_file, length)
    if len(data) < length:
        # The pipe closed mid-message; a truncated payload is not a valid
        # command, so report EOF instead of decoding a partial message.
        _logger.debug('Pipe EOF encountered')
        return None, None
    command = CommandType(header[:2])
    data = data.decode('utf8')
    _logger.debug('Received command, data: [%s]', data)
    return command, data