from __future__ import annotations
import contextlib
import random
import torch
from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertInfo
from torch import nn
import src
from src.client.remote_block import RemoteTransformerBlock
from src.client.remote_sequence_info import RemoteSequenceInfo
from src.data_structures import UID_DELIMITER
from src.dht_utils import _create_remote_modules_from_infos

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class RemoteSequential(nn.Module):
"""
A sequence of transformer blocks hosted by the swarm.
"""

    def __init__(self, config: src.DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
if prefix.endswith(UID_DELIMITER):
logger.warning(
f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'."
f"This will cause {self.__class__.__name__} to look for modules under "
f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
)
super().__init__()
self.config = config
self.dht = dht
self.prefix = prefix
self.max_retries = max_retries
self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
logger.debug(f"Remote block uids: {block_uids}")
self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)

    def forward(self, inputs: torch.Tensor):
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
for block_index in range(self.config.n_layer):
for retry_index in range(self.max_retries):
try:
block = self[block_index]
(outputs,) = block(inputs)
assert isinstance(outputs, torch.Tensor)
assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
inputs = outputs
break
                except Exception as e:
                    if retry_index == self.max_retries - 1:
                        raise  # out of retries: propagate the original exception
                    else:
                        logger.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
return inputs

    def __getitem__(self, block_index: int):
assert 0 <= block_index < self.config.n_layer
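        # Note: this creates a fresh remote module handle on every access instead of caching it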
(module,) = _create_remote_modules_from_infos([self.remote_sequence_info.block_infos[block_index]], self.p2p)
return module

    def __iter__(self):
for block_index in range(self.config.n_layer):
yield self[block_index]

    def __len__(self):
return len(self.remote_sequence_info)

    def inference_session(self) -> RemoteSequentialInferenceSession:
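        # Refresh the mapping from blocks to servers from the DHT before opening a session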
self.remote_sequence_info.update_()
return RemoteSequentialInferenceSession(self.remote_sequence_info, self.p2p)


class RemoteSequentialInferenceSession:
"""An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""

    def __init__(self, remote_sequence_info: RemoteSequenceInfo, p2p: P2P):
self.remote_sequence_info = remote_sequence_info
self.p2p = p2p
self.closed = False
self.stack = contextlib.ExitStack()
self.active_sessions = []

    def __enter__(self):
assert not self.closed
self.stack.__enter__()
# TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
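        # Greedily cover blocks [0, n_layer): pick a random contiguous span of blocks
        # hosted by a single peer, open an inference session on that span, then
        # continue from the first block after the span until the sequence is covered.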
current_block = 0
while current_block != len(self.remote_sequence_info):
candidate_spans = self.remote_sequence_info.spans_containing_block[current_block]
            chosen_span = random.choice(candidate_spans)  # TODO: this is temporary code
assert chosen_span.start <= current_block < chosen_span.end
# TODO begin throwaway prototype code
remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
            _ = remote.info  # TODO fix
span_uids = self.remote_sequence_info.block_uids[current_block : chosen_span.end]
remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
self.active_sessions.append(remote.inference_session())
self.stack.enter_context(self.active_sessions[-1])
current_block = chosen_span.end
# TODO end throwaway prototype code
return self

    def step(self, inputs: torch.Tensor):
assert not self.closed
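        # Run the new inputs through each span's session in order; per-step server-side
        # state (e.g. attention caches) is kept by the individual sessions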
for session in self.active_sessions:
outputs = session.step(inputs)
assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
inputs = outputs
return inputs

    def close(self, *exc_details):
"""Finish a given inference session, close the underlying connection"""
if not self.closed:
self.stack.__exit__(*exc_details or (None, None, None))
self.active_sessions.clear()
self.closed = True

    def __exit__(self, *exc_details):
self.close(*exc_details)

    def __del__(self):
self.close()