from __future__ import annotations

import contextlib
import logging
import random

import torch
from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertInfo
from torch import nn

import src
from src.client.remote_block import RemoteTransformerBlock
from src.client.remote_sequence_info import RemoteSequenceInfo
from src.data_structures import UID_DELIMITER
from src.dht_utils import _create_remote_modules_from_infos

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class RemoteSequential(nn.Module):
    """
    A sequence of transformer blocks hosted by the swarm.
    """

    def __init__(self, config: src.DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
        if prefix.endswith(UID_DELIMITER):
            logger.warning(
                f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'. "
                f"This will cause {self.__class__.__name__} to look for modules under "
                f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
            )

        super().__init__()
        self.config = config
        self.dht = dht
        self.prefix = prefix
        self.max_retries = max_retries
        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())

        block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
        logger.debug(f"Remote block uids: {block_uids}")
        self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)

    def forward(self, inputs: torch.Tensor):
        assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
        for block_index in range(self.config.n_layer):
            for retry_index in range(self.max_retries):
                try:
                    block = self[block_index]
                    (outputs,) = block(inputs)
                    assert isinstance(outputs, torch.Tensor)
                    assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
                    inputs = outputs
                    break
                except Exception as e:
                    if retry_index == self.max_retries - 1:
                        raise  # out of retries: re-raise with the original traceback
                    else:
                        logger.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
        return inputs

    def __getitem__(self, block_index: int):
        assert 0 <= block_index < self.config.n_layer
        (module,) = _create_remote_modules_from_infos([self.remote_sequence_info.block_infos[block_index]], self.p2p)
        return module

    def __iter__(self):
        for block_index in range(self.config.n_layer):
            yield self[block_index]

    def __len__(self):
        return len(self.remote_sequence_info)

    def inference_session(self) -> RemoteSequentialInferenceSession:
        self.remote_sequence_info.update_()
        return RemoteSequentialInferenceSession(self.remote_sequence_info, self.p2p)
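

# A minimal usage sketch for RemoteSequential, assuming a running swarm whose servers
# announce blocks under the given prefix. The initial peer multiaddress, the "bloom"
# prefix, and the config loading call below are illustrative placeholders, not values
# prescribed by this repo:
#
#   dht = DHT(initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/Qm..."], start=True, client_mode=True)
#   config = src.DistributedBloomConfig.from_pretrained("bigscience/bloom")  # hypothetical loading path
#   remote_blocks = RemoteSequential(config, dht, prefix="bloom")
#   hidden_states = torch.randn(1, 8, config.n_embed)  # (batch, seq_len, hidden)
#   outputs = remote_blocks(hidden_states)  # runs all n_layer blocks remotely, retrying failed blocks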


class RemoteSequentialInferenceSession:
    """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""

    def __init__(self, remote_sequence_info: RemoteSequenceInfo, p2p: P2P):
        self.remote_sequence_info = remote_sequence_info
        self.p2p = p2p
        self.closed = False
        self.stack = contextlib.ExitStack()
        self.active_sessions = []

    def __enter__(self):
        assert not self.closed
        self.stack.__enter__()
        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
        current_block = 0
        while current_block < len(self.remote_sequence_info):
            candidate_spans = self.remote_sequence_info.spans_containing_block[current_block]
            chosen_span = random.choice(candidate_spans)  # TODO this is temporary code
            assert chosen_span.start <= current_block < chosen_span.end

            # TODO begin throwaway prototype code
            remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
            _ = remote.info  # TODO fix
            span_uids = self.remote_sequence_info.block_uids[current_block : chosen_span.end]
            remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
            self.active_sessions.append(remote.inference_session())
            self.stack.enter_context(self.active_sessions[-1])
            current_block = chosen_span.end
            # TODO end throwaway prototype code

        return self

    def step(self, inputs: torch.Tensor):
        assert not self.closed
        for session in self.active_sessions:
            outputs = session.step(inputs)
            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
            inputs = outputs
        return inputs

    def close(self, *exc_details):
        """Finish a given inference session, close the underlying connection"""
        if not self.closed:
            self.stack.__exit__(*exc_details or (None, None, None))
            self.active_sessions.clear()
            self.closed = True

    def __exit__(self, *exc_details):
        self.close(*exc_details)

    def __del__(self):
        self.close()
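

# A minimal sketch of multi-step inference with the session API above, assuming
# `remote_blocks` and `config` were built as in the earlier example. Each step()
# call pushes one new token's hidden states through every chosen server span;
# the loop below is illustrative, not a prescribed API:
#
#   with remote_blocks.inference_session() as session:
#       for step_index in range(10):
#           hidden_states = torch.randn(1, 1, config.n_embed)  # one new token per step
#           hidden_states = session.step(hidden_states)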