File size: 5,540 Bytes
5bdad4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from __future__ import annotations

import contextlib
import logging
import random

import torch
from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertInfo
from torch import nn

import src
from src.client.remote_block import RemoteTransformerBlock
from src.client.remote_sequence_info import RemoteSequenceInfo
from src.data_structures import UID_DELIMITER
from src.dht_utils import _create_remote_modules_from_infos

# Select hivemind's log-handler mode before any logger is created.
# NOTE(review): "in_root_logger" presumably attaches hivemind's handler to the root logger --
# confirm against the hivemind logging utilities.
use_hivemind_log_handler("in_root_logger")
# Module-level logger; note that it is keyed by this file's path (__file__), not by __name__.
logger = get_logger(__file__)


class RemoteSequential(nn.Module):
    """
    A sequence of transformer blocks hosted by the swarm.

    Blocks are addressed by uids of the form ``{prefix}{UID_DELIMITER}{index}`` for every ``index`` in
    ``range(config.n_layer)``. A remote module is created on demand each time a block is indexed.

    :param config: model configuration; this class reads ``n_layer`` and ``n_embed`` from it
    :param dht: DHT instance used to build the sequence info; a replica of its p2p instance is
        handed to the remote modules and inference sessions
    :param prefix: uid prefix shared by all blocks; a warning is logged if it already ends with ``UID_DELIMITER``
    :param max_retries: number of attempts ``forward`` makes for each block before re-raising the error
    """

    def __init__(self, config: src.DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
        if prefix.endswith(UID_DELIMITER):
            # the delimiter is appended again below, so a trailing delimiter yields a doubled one
            logger.warning(
                f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'. "
                f"This will cause {self.__class__.__name__} to look for modules under "
                f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
            )

        super().__init__()
        self.config = config
        self.dht = dht
        self.prefix = prefix
        self.max_retries = max_retries
        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())

        block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))

        logger.debug(f"Remote block uids: {block_uids}")
        self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)

    def forward(self, inputs: torch.Tensor):
        """
        Run inputs through all ``config.n_layer`` remote blocks in order and return the final output.

        :param inputs: a 3-dimensional tensor whose last dimension equals ``config.n_embed``
        :returns: a tensor with the same shape as ``inputs``
        :raises Exception: whatever the last attempt for a block raised, once ``max_retries`` attempts failed

        NOTE(review): if ``max_retries`` is less than 1, the retry loop never runs and every block
        is skipped silently -- confirm that callers never pass such a value.
        """
        assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
        for block_index in range(self.config.n_layer):
            for retry_index in range(self.max_retries):
                try:
                    # the remote module is re-created on every attempt
                    block = self[block_index]
                    (outputs,) = block(inputs)
                    assert isinstance(outputs, torch.Tensor)
                    assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
                    inputs = outputs
                    break
                except Exception as e:
                    if retry_index == self.max_retries - 1:
                        # out of attempts: bare raise keeps the original traceback intact
                        raise
                    # use the module logger (not the root ``logging`` module) so the message
                    # goes through the same handler as the rest of this file
                    logger.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
        return inputs

    def __getitem__(self, block_index: int):
        """Create and return a remote module for the block at ``block_index`` (0 <= index < n_layer)."""
        assert 0 <= block_index < self.config.n_layer
        (module,) = _create_remote_modules_from_infos([self.remote_sequence_info.block_infos[block_index]], self.p2p)
        return module

    def __iter__(self):
        """Yield a freshly created remote module for every block, in order."""
        for block_index in range(self.config.n_layer):
            yield self[block_index]

    def __len__(self):
        """Return the length reported by the underlying sequence info."""
        return len(self.remote_sequence_info)

    def inference_session(self) -> RemoteSequentialInferenceSession:
        """Call ``update_()`` on the sequence info, then return a new (not yet entered) inference session."""
        self.remote_sequence_info.update_()
        return RemoteSequentialInferenceSession(self.remote_sequence_info, self.p2p)


class RemoteSequentialInferenceSession:
    """
    An interface to a multi-step *inference* session for a sequence of remote transformer blocks

    Use it as a context manager: entering opens one nested inference session per chosen span
    until the whole sequence is covered; ``step`` then feeds inputs through those sessions in order.
    A session cannot be re-entered once it has been closed.

    :param remote_sequence_info: sequence info; this class reads its length, ``block_infos``,
        ``block_uids`` and ``spans_containing_block``
    :param p2p: p2p instance passed to every remote block created by this session
    """

    def __init__(self, remote_sequence_info: RemoteSequenceInfo, p2p: P2P):
        self.remote_sequence_info = remote_sequence_info
        self.p2p = p2p
        self.closed = False
        # owns all nested per-span sessions so that they are exited together
        self.stack = contextlib.ExitStack()
        self.active_sessions = []

    def __enter__(self):
        """
        Open nested sessions that together cover all blocks and return self.

        If opening the chain fails midway, the sessions opened so far are closed before the error
        propagates, because ``__exit__`` is not invoked when ``__enter__`` raises.
        """
        assert not self.closed
        self.stack.__enter__()
        try:
            # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
            current_block = 0
            while current_block != len(self.remote_sequence_info):
                candidate_spans = self.remote_sequence_info.spans_containing_block[current_block]
                chosen_span = random.choice(candidate_spans)  # TODO this is a temporary code
                assert chosen_span.start <= current_block < chosen_span.end

                # TODO begin throwaway prototype code
                remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
                _ = remote.info  # TODO fix
                # address the rest of the chosen span at once: uids joined by spaces, served by the span's peer
                span_uids = self.remote_sequence_info.block_uids[current_block : chosen_span.end]
                remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
                self.active_sessions.append(remote.inference_session())
                self.stack.enter_context(self.active_sessions[-1])
                current_block = chosen_span.end
                # TODO end throwaway prototype code
        except BaseException as exc:
            # release whatever was opened before the failure, then let the error propagate
            self.close(type(exc), exc, exc.__traceback__)
            raise

        return self

    def step(self, inputs: torch.Tensor):
        """
        Pass inputs through every active nested session in order and return the final output.

        Each nested session must return a result with the same ``shape`` as its input.
        """
        assert not self.closed
        for session in self.active_sessions:
            outputs = session.step(inputs)
            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
            inputs = outputs
        return inputs

    def close(self, *exc_details):
        """Finish a given inference session, close the underlying connection"""
        if not self.closed:
            # without exception details, exit the stack as if the with-block completed normally
            self.stack.__exit__(*exc_details or (None, None, None))
            self.active_sessions.clear()
            self.closed = True

    def __exit__(self, *exc_details):
        """Close the session; returns None, so exceptions from the with-block are not suppressed."""
        self.close(*exc_details)

    def __del__(self):
        # last-resort cleanup in case the session was never closed explicitly
        self.close()