"""Build a health snapshot of the swarm: bootstrap peers, served models, and servers."""

import datetime
import time
from collections import Counter
from contextlib import suppress
from dataclasses import asdict
from functools import partial

import hivemind
import numpy as np
from multiaddr import Multiaddr
from petals.data_structures import UID_DELIMITER, ServerState
from petals.utils.dht import compute_spans, get_remote_module_infos

import config
from data_structures import ModelInfo
from p2p_utils import check_reachability_parallel, get_peers_ips, extract_peer_ip_info

logger = hivemind.get_logger(__name__)


def _list_models(dht: hivemind.DHT) -> list:
    """Return the models from ``config.MODELS`` followed by custom models found in the DHT.

    Custom models are read from the ``"_petals.models"`` DHT record. An entry is
    kept only if its prefix is not already configured and its repository starts
    with ``https://huggingface.co/``. Entries for which ``ModelInfo.from_dict``
    (or the follow-up checks) raise ``TypeError``/``ValueError`` are skipped.
    Custom models are appended ordered by descending ``num_blocks``, then prefix.
    """
    models = config.MODELS[:]
    model_index = dht.get("_petals.models", latest=True)
    if model_index is None or not isinstance(model_index.value, dict):
        return models

    known_prefixes = {known.dht_prefix for known in models}
    extras = []
    for dht_prefix, entry in model_index.value.items():
        if dht_prefix in known_prefixes:
            continue
        # Malformed entries are dropped silently rather than failing the whole refresh.
        with suppress(TypeError, ValueError):
            candidate = ModelInfo.from_dict(entry.value)
            repository = candidate.repository
            if repository is not None and repository.startswith("https://huggingface.co/"):
                candidate.dht_prefix = dht_prefix
                candidate.official = False
                extras.append(candidate)

    extras.sort(key=lambda info: (-info.num_blocks, info.dht_prefix))
    models.extend(extras)
    return models


def _make_server_row(peer_id, span, state, show_public_name, peers_info, model_spans) -> dict:
    """Assemble the report row for one server span of a model.

    ``model_spans`` is the peer-id -> span mapping of the same model; it is
    scanned for servers whose ``next_pings`` contain an entry for ``peer_id``.
    """
    peer_key = str(peer_id)
    server_info = span.server_info

    row = {
        "short_peer_id": "..." + peer_key[-6:],
        "peer_id": peer_id,
        "peer_ip_info": peers_info.get(peer_key, "unknown"),
        "show_public_name": show_public_name,
        "state": state,
        "span": span,
        "adapters": [dict(name=name, short_name=name.split("/")[-1]) for name in server_info.adapters],
    }

    pings_to_me = {}
    for origin_id, origin in model_spans.items():
        next_pings = origin.server_info.next_pings
        if next_pings is not None and peer_key in next_pings:
            pings_to_me[str(origin_id)] = next_pings[peer_key]
    row["pings_to_me"] = pings_to_me

    if server_info.cache_tokens_left is not None:
        # The factor of 2 accounts for both keys and values of every block in the span.
        row["cache_tokens_left_per_block"] = server_info.cache_tokens_left // (span.length * 2)
    return row


def fetch_health_state(dht: hivemind.DHT) -> dict:
    """Query the DHT and peers and return a snapshot describing the swarm's health.

    The returned dict has the keys ``bootstrap_states``, ``top_contributors``,
    ``model_reports``, ``reachability_issues``, ``last_updated``,
    ``update_period`` and ``update_duration`` (seconds spent in this call,
    measured with ``time.perf_counter``).
    """
    started_at = time.perf_counter()

    # Bootstrap peers: de-duplicate while keeping the order of config.INITIAL_PEERS.
    bootstrap_peer_ids = []
    for addr in config.INITIAL_PEERS:
        peer_id = hivemind.PeerID.from_base58(Multiaddr(addr)["p2p"])
        if peer_id in bootstrap_peer_ids:
            continue
        bootstrap_peer_ids.append(peer_id)

    # reach_infos is indexed by peer id; each entry is read via "ok" (and "error" on failure).
    reach_infos = dht.run_coroutine(partial(check_reachability_parallel, bootstrap_peer_ids))
    bootstrap_states = [
        "online" if reach_infos[bootstrap_id]["ok"] else "unreachable" for bootstrap_id in bootstrap_peer_ids
    ]

    models = _list_models(dht)
    logger.info(f"Fetching info for models {[info.name for info in models]}")

    # Fetch all blocks of all models in one request, then slice the flat result per model.
    block_uids = [f"{model.dht_prefix}{UID_DELIMITER}{i}" for model in models for i in range(model.num_blocks)]
    module_infos = get_remote_module_infos(dht, block_uids, latest=True)

    model_servers = {}
    all_servers = {}
    first_block = 0
    for model in models:
        next_first_block = first_block + model.num_blocks
        spans = compute_spans(module_infos[first_block:next_first_block], min_state=ServerState.OFFLINE)
        model_servers[model.dht_prefix] = spans
        all_servers.update(spans)
        first_block = next_first_block

    online_servers = [server_id for server_id, span in all_servers.items() if span.state == ServerState.ONLINE]
    reach_infos.update(dht.run_coroutine(partial(check_reachability_parallel, online_servers, fetch_info=True)))

    peers_info = {}
    for peer in dht.run_coroutine(get_peers_ips):
        peer_key = str(peer.peer_id)
        peers_info[peer_key] = {
            "location": extract_peer_ip_info(str(peer.addrs[0])),
            "multiaddrs": [str(multiaddr) for multiaddr in peer.addrs],
        }

    top_contributors = Counter()
    model_reports = []
    for model in models:
        spans = model_servers[model.dht_prefix]
        block_healthy = np.zeros(model.num_blocks, dtype=bool)
        server_rows = []

        for peer_id, span in sorted(spans.items()):
            # Servers that were never probed are treated as reachable.
            if peer_id in reach_infos and not reach_infos[peer_id]["ok"]:
                state = "unreachable"
            else:
                state = span.state.name.lower()

            is_online = state == "online"
            if is_online:
                block_healthy[span.start : span.end] = True

            show_public_name = is_online and span.length >= 10
            if model.official and span.server_info.public_name and show_public_name:
                top_contributors[span.server_info.public_name] += span.length

            server_rows.append(_make_server_row(peer_id, span, state, show_public_name, peers_info, spans))

        # Keep the dict(...) call form: a key shared by the explicit keywords and
        # asdict(model) raises TypeError here, whereas a dict literal would
        # silently let the later value win.
        model_reports.append(
            dict(
                name=model.name,
                short_name=model.short_name,
                state="healthy" if block_healthy.all() else "broken",
                server_rows=server_rows,
                **asdict(model),
            )
        )

    reachability_issues = [
        {"peer_id": peer_id, "err": info["error"]}
        for peer_id, info in sorted(reach_infos.items())
        if not info["ok"]
    ]

    return {
        "bootstrap_states": bootstrap_states,
        "top_contributors": top_contributors,
        "model_reports": model_reports,
        "reachability_issues": reachability_issues,
        "last_updated": datetime.datetime.now(datetime.timezone.utc),
        "update_period": config.UPDATE_PERIOD,
        "update_duration": time.perf_counter() - started_at,
    }