#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Reference:
 - [graphrag](https://github.com/microsoft/graphrag)
"""

import html
import logging
from typing import Any, List, cast

import networkx as nx
from graspologic.partition import hierarchical_leiden
from graspologic.utils import largest_connected_component
from networkx import is_empty

log = logging.getLogger(__name__)


def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
    """Return a copy of *graph* whose nodes and edges are inserted in a deterministic order.

    Nodes are added sorted by node id. For undirected graphs, each edge's
    endpoints are first ordered so that ``source <= target``. Edges are then
    added sorted by the string key ``"source -> target"``. The result is a
    ``DiGraph`` when *graph* is directed and a ``Graph`` otherwise. Node and
    edge attribute dicts are carried over.
    """
    fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()

    sorted_nodes = sorted(graph.nodes(data=True), key=lambda x: x[0])
    fixed_graph.add_nodes_from(sorted_nodes)

    edges = list(graph.edges(data=True))

    # In an undirected graph, A -> B and B -> A are the same edge. Depending on
    # insertion order, consumers iterating graph.nodes() may see [A, B] or [B, A].
    # Some downstream logic depends on that order, so sort the endpoints of each
    # edge to always produce the same order.
    if not graph.is_directed():

        def _sort_source_target(edge):
            source, target, edge_data = edge
            if source > target:
                source, target = target, source
            return source, target, edge_data

        edges = [_sort_source_target(edge) for edge in edges]

    def _get_edge_key(source: Any, target: Any) -> str:
        return f"{source} -> {target}"

    edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))

    fixed_graph.add_edges_from(edges)
    return fixed_graph


def normalize_node_names(graph: nx.Graph | nx.DiGraph) -> nx.Graph | nx.DiGraph:
    """Relabel every node to ``html.unescape(name.upper().strip())``.

    Node ids must be strings (``.upper()`` is called on them). The relabeling is
    done with ``nx.relabel_nodes``, which returns a new graph. If two names
    normalize to the same value, see the networkx documentation for how the
    collision is resolved.
    """
    node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()}  # type: ignore
    return nx.relabel_nodes(graph, node_mapping)


def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
    """Return the largest connected component of *graph* with normalized, stably ordered nodes and edges.

    The input graph is copied first. Node names in the result are normalized
    with :func:`normalize_node_names`, so they may differ from the ids in *graph*.
    """
    graph = graph.copy()
    graph = cast(nx.Graph, largest_connected_component(graph))
    graph = normalize_node_names(graph)
    return _stabilize_graph(graph)


def _compute_leiden_communities(
    graph: nx.Graph | nx.DiGraph,
    max_cluster_size: int,
    use_lcc: bool,
    seed: int = 0xDEADBEEF,
) -> dict[int, dict[str, int]]:
    """Run hierarchical Leiden and group the assignments by level.

    Args:
        graph: Graph to cluster.
        max_cluster_size: Passed to ``hierarchical_leiden``.
        use_lcc: When true, cluster only the largest connected component (see
            :func:`stable_largest_connected_component`). The returned node ids
            are then the *normalized* names, which may not match ids in *graph*.
        seed: Random seed forwarded to ``hierarchical_leiden``.

    Returns:
        ``{level: {node_id: cluster_id}}``, or an empty dict when ``is_empty(graph)``
        is true.
    """
    results: dict[int, dict[str, int]] = {}
    if is_empty(graph):
        return results
    if use_lcc:
        graph = stable_largest_connected_component(graph)

    community_mapping = hierarchical_leiden(
        graph, max_cluster_size=max_cluster_size, random_seed=seed
    )
    for partition in community_mapping:
        results.setdefault(partition.level, {})[partition.node] = partition.cluster
    return results


def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
    """Cluster *graph* with hierarchical Leiden and summarize the communities at each level.

    Args:
        graph: Graph to cluster. Each node's ``rank`` attribute (default 0)
            multiplied by its ``weight`` attribute (default 1) is added to the
            weight of the community the node belongs to.
        args: Options. Recognized keys: ``max_cluster_size`` (default 12),
            ``use_lcc`` (default True), ``verbose`` (default False), ``seed``
            (default 0xDEADBEEF) and ``levels`` (default: every level produced
            by the clustering, in ascending order).

    Returns:
        ``{level: {community_id: {"weight": number, "nodes": [node_id, ...]}}}``.
        ``community_id`` is the cluster id converted to ``str``. At each level
        the weights are divided by the largest weight, unless that largest
        weight is 0, in which case they are left unchanged. Returns ``{}`` for
        a graph with no nodes.

    Raises:
        KeyError: if ``args["levels"]`` contains a level the clustering did not produce.
    """
    max_cluster_size = args.get("max_cluster_size", 12)
    use_lcc = args.get("use_lcc", True)
    if args.get("verbose", False):
        log.info(
            "Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
        )
    if not graph.nodes():
        return {}

    node_id_to_community_map = _compute_leiden_communities(
        graph=graph,
        max_cluster_size=max_cluster_size,
        use_lcc=use_lcc,
        seed=args.get("seed", 0xDEADBEEF),
    )
    levels = args.get("levels")

    # If they don't pass in levels, use them all
    if levels is None:
        levels = sorted(node_id_to_community_map.keys())

    results_by_level: dict[int, dict[str, dict[str, Any]]] = {}
    for level in levels:
        result: dict[str, dict[str, Any]] = {}
        results_by_level[level] = result
        for node_id, raw_community_id in node_id_to_community_map[level].items():
            # With use_lcc, clustering ran on a relabeled (normalized) copy, so a
            # returned id may not exist in the original graph. Skip it rather than
            # crash on the attribute lookup below.
            if node_id not in graph:
                log.warning("Node %s not found in the graph; skipping it.", node_id)
                continue
            node_data = graph.nodes[node_id]
            community = result.setdefault(str(raw_community_id), {"weight": 0, "nodes": []})
            community["nodes"].append(node_id)
            community["weight"] += node_data.get("rank", 0) * node_data.get("weight", 1)

        weights = [comm["weight"] for comm in result.values()]
        if not weights:
            continue
        max_weight = max(weights)
        # rank defaults to 0, so every weight can be 0; dividing would raise
        # ZeroDivisionError.
        if max_weight == 0:
            continue
        for comm in result.values():
            comm["weight"] /= max_weight

    return results_by_level


def add_community_info2graph(graph: nx.Graph, nodes: List[str], community_title):
    """Record *community_title* in the ``communities`` list attribute of each node in *nodes*.

    The list attribute is created on first use. A title already present on a
    node is not appended again. Raises ``KeyError`` if a node in *nodes* is not
    in *graph*.
    """
    for n in nodes:
        communities = graph.nodes[n].setdefault("communities", [])
        if community_title not in communities:
            communities.append(community_title)