#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""
Reference:
 - [graphrag](https://github.com/microsoft/graphrag)
"""

import logging
import numbers
import re
import traceback
from dataclasses import dataclass
from timeit import default_timer as timer
from typing import Any, Mapping, Callable

import networkx as nx
import tiktoken

from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
from rag.llm.chat_model import Base as CompletionLLM
from rag.utils import num_tokens_from_string

DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
DEFAULT_ENTITY_TYPES = ["organization", "person", "location", "event", "time"]
ENTITY_EXTRACTION_MAX_GLEANINGS = 1


@dataclass
class GraphExtractionResult:
    """Unipartite graph extraction result class definition."""

    # Graph built from every successfully processed document.
    output: nx.Graph
    # Maps the index of each successfully processed document to its text.
    source_docs: dict[Any, Any]


class GraphExtractor:
    """Unipartite graph extractor class definition.

    Prompts an LLM to list entities and relationships found in each input
    text, then parses the delimited records it returns into an undirected
    ``networkx`` graph.
    """

    _llm: CompletionLLM
    _join_descriptions: bool
    _tuple_delimiter_key: str
    _record_delimiter_key: str
    _entity_types_key: str
    _input_text_key: str
    _completion_delimiter_key: str
    _entity_name_key: str
    _input_descriptions_key: str
    _extraction_prompt: str
    _summarization_prompt: str
    _loop_args: dict[str, Any]
    _max_gleanings: int
    _on_error: ErrorHandlerFn

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        prompt: str | None = None,
        tuple_delimiter_key: str | None = None,
        record_delimiter_key: str | None = None,
        input_text_key: str | None = None,
        entity_types_key: str | None = None,
        completion_delimiter_key: str | None = None,
        join_descriptions=True,
        encoding_model: str | None = None,
        max_gleanings: int | None = None,
        on_error: ErrorHandlerFn | None = None,
    ):
        """Init method definition.

        Args:
            llm_invoker: chat model used for extraction.
            prompt: extraction prompt template; defaults to
                ``GRAPH_EXTRACTION_PROMPT``.
            tuple_delimiter_key, record_delimiter_key, input_text_key,
            entity_types_key, completion_delimiter_key: names of the prompt
                variables that carry the respective values.
            join_descriptions: when true, descriptions of a repeated entity
                or edge are concatenated; otherwise see ``_process_results``.
            encoding_model: tiktoken encoding name, ``cl100k_base`` if unset.
            max_gleanings: number of follow-up "anything missed?" rounds;
                defaults to ``ENTITY_EXTRACTION_MAX_GLEANINGS``.
            on_error: called as ``on_error(exc, stack, details)`` whenever a
                document fails; defaults to a no-op.
        """
        # TODO: streamline construction
        self._llm = llm_invoker
        self._join_descriptions = join_descriptions
        self._input_text_key = input_text_key or "input_text"
        self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
        self._record_delimiter_key = record_delimiter_key or "record_delimiter"
        self._completion_delimiter_key = (
            completion_delimiter_key or "completion_delimiter"
        )
        self._entity_types_key = entity_types_key or "entity_types"
        self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
        self._max_gleanings = (
            max_gleanings
            if max_gleanings is not None
            else ENTITY_EXTRACTION_MAX_GLEANINGS
        )
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)

        # Construct the looping arguments: bias the model towards answering
        # the continuation question with the single token YES or NO.
        encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
        yes = encoding.encode("YES")
        no = encoding.encode("NO")
        self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}

    def __call__(
        self,
        texts: list[str],
        prompt_variables: dict[str, Any] | None = None,
        callback: Callable | None = None,
    ) -> GraphExtractionResult:
        """Call method definition.

        Args:
            texts: documents to extract from; the list index is the doc id.
            prompt_variables: optional overrides for the prompt variables
                (delimiters and entity types).
            callback: optional progress hook, invoked with a ``msg`` keyword
                after each successfully processed document.

        Returns:
            The merged graph plus the texts that were processed successfully.
            A document whose extraction raises is logged, reported through
            ``on_error`` and left out of the result.
        """
        if prompt_variables is None:
            prompt_variables = {}
        all_records: dict[int, str] = {}
        source_doc_map: dict[int, str] = {}

        entity_types = (
            prompt_variables.get(self._entity_types_key) or DEFAULT_ENTITY_TYPES
        )
        # A ready-made string is passed through as is: joining it would put a
        # comma between every single character.
        if not isinstance(entity_types, str):
            entity_types = ",".join(entity_types)

        # Wire defaults into the prompt variables
        prompt_variables = {
            **prompt_variables,
            self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
            or DEFAULT_TUPLE_DELIMITER,
            self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
            or DEFAULT_RECORD_DELIMITER,
            self._completion_delimiter_key: prompt_variables.get(
                self._completion_delimiter_key
            )
            or DEFAULT_COMPLETION_DELIMITER,
            self._entity_types_key: entity_types,
        }

        st = timer()
        total = len(texts)
        total_token_count = 0
        for doc_index, text in enumerate(texts):
            try:
                # Invoke the entity extraction
                result, token_count = self._process_document(text, prompt_variables)
                source_doc_map[doc_index] = text
                all_records[doc_index] = result
                total_token_count += token_count
                if callback:
                    callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
            except Exception as e:
                # Best effort: one failing document must not abort the batch.
                logging.exception("error extracting graph")
                self._on_error(
                    e,
                    traceback.format_exc(),
                    {
                        "doc_index": doc_index,
                        "text": text,
                    },
                )

        output = self._process_results(
            all_records,
            prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
            prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
        )

        return GraphExtractionResult(
            output=output,
            source_docs=source_doc_map,
        )

    def _process_document(
        self, text: str, prompt_variables: dict[str, str]
    ) -> tuple[str, int]:
        """Run the extraction prompt (plus gleaning rounds) on one document.

        Returns:
            ``(results, token_count)`` where ``results`` is the concatenation
            of all LLM responses and ``token_count`` covers the initial
            prompt and its response only (gleaning rounds are not counted).
        """
        variables = {
            **prompt_variables,
            self._input_text_key: text,
        }
        prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
        gen_conf = {"temperature": 0.3}
        # Normalise an empty/None answer so the concatenations below cannot
        # raise TypeError and discard the whole document.
        response = self._llm.chat(prompt, [], gen_conf) or ""
        token_count = num_tokens_from_string(prompt + response)

        results = response
        history = [
            {"role": "system", "content": prompt},
            {"role": "assistant", "content": response},
        ]

        # Repeat to ensure we maximize entity count
        for i in range(self._max_gleanings):
            glean_prompt = perform_variable_replacements(
                CONTINUE_PROMPT, history=history, variables=variables
            )
            history.append({"role": "user", "content": glean_prompt})
            response = self._llm.chat("", history, gen_conf) or ""
            results += response

            # if this is the final glean, don't bother updating the continuation flag
            if i >= self._max_gleanings - 1:
                break

            history.append({"role": "assistant", "content": response})
            history.append({"role": "user", "content": LOOP_PROMPT})
            continuation = self._llm.chat("", history, self._loop_args)
            if continuation != "YES":
                break

        return results, token_count

    def _process_results(
        self,
        results: dict[int, str],
        tuple_delimiter: str,
        record_delimiter: str,
    ) -> nx.Graph:
        """Parse the result string to create an undirected unipartite graph.

        Args:
            - results - dict of results from the extraction chain
            - tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
            - record_delimiter - delimiter between records, default is '##'
        Returns:
            - output - undirected ``networkx.Graph``; every node carries
              ``entity_type``, ``description``, ``source_id``, ``weight`` and
              ``rank`` (its degree), every edge ``weight``, ``description``
              and ``source_id``
        """
        graph = nx.Graph()
        for source_doc_id, extracted_data in results.items():
            for raw_record in extracted_data.split(record_delimiter):
                # Drop the parentheses that wrap each record.
                record = re.sub(r"^\(|\)$", "", raw_record.strip())
                record_attributes = record.split(tuple_delimiter)
                record_type = record_attributes[0]

                if record_type == '"entity"' and len(record_attributes) >= 4:
                    self._merge_entity(graph, record_attributes, source_doc_id)
                elif record_type == '"relationship"' and len(record_attributes) >= 5:
                    self._merge_relationship(graph, record_attributes, source_doc_id)

        for node_name, degree in graph.degree:
            graph.nodes[node_name]["rank"] = int(degree)
        return graph

    def _merge_entity(
        self, graph: nx.Graph, record_attributes: list[str], source_doc_id: Any
    ) -> None:
        """Add an ``"entity"`` record as a node, or merge it into an existing one."""
        entity_name = clean_str(record_attributes[1].upper())
        entity_type = clean_str(record_attributes[2].upper())
        entity_description = clean_str(record_attributes[3])

        if entity_name not in graph:
            graph.add_node(
                entity_name,
                entity_type=entity_type,
                description=entity_description,
                source_id=str(source_doc_id),
                weight=1,
            )
            return

        node = graph.nodes[entity_name]
        if self._join_descriptions:
            node["description"] = "\n".join(
                _merge_unique(_unpack_descriptions(node), entity_description)
            )
        elif len(entity_description) > len(node["description"]):
            # Without joining, the longest description wins.
            node["description"] = entity_description
        node["source_id"] = ", ".join(
            _merge_unique(_unpack_source_ids(node), str(source_doc_id))
        )
        # An empty type never overwrites a known one.
        if entity_type != "":
            node["entity_type"] = entity_type

    def _merge_relationship(
        self, graph: nx.Graph, record_attributes: list[str], source_doc_id: Any
    ) -> None:
        """Add a ``"relationship"`` record as an edge, or merge it into an existing one."""
        source = clean_str(record_attributes[1].upper())
        target = clean_str(record_attributes[2].upper())
        edge_description = clean_str(record_attributes[3])
        edge_source_id = clean_str(str(source_doc_id))
        # The last field is the relationship strength emitted by the LLM.
        weight = _parse_weight(record_attributes[-1])

        # Endpoints that were never declared as entities get placeholder nodes.
        for endpoint in (source, target):
            if endpoint not in graph:
                graph.add_node(
                    endpoint,
                    entity_type="",
                    description="",
                    source_id=edge_source_id,
                    weight=1,
                )

        if graph.has_edge(source, target):
            edge_data = graph.get_edge_data(source, target)
            if edge_data is not None:
                # Repeated relationships accumulate their weights.
                weight += edge_data["weight"]
                if self._join_descriptions:
                    edge_description = "\n".join(
                        _merge_unique(_unpack_descriptions(edge_data), edge_description)
                    )
                edge_source_id = ", ".join(
                    _merge_unique(_unpack_source_ids(edge_data), str(source_doc_id))
                )
        graph.add_edge(
            source,
            target,
            weight=weight,
            description=edge_description,
            source_id=edge_source_id,
        )


def _parse_weight(value: Any, default: float = 1.0) -> float:
    """Convert a relationship-strength field to ``float``.

    Fields produced by ``str.split`` are always strings, so the text is
    parsed (ignoring surrounding whitespace and quotes); real numbers are
    converted directly. Anything unparsable yields ``default``.
    """
    try:
        if isinstance(value, numbers.Number):
            return float(value)
        return float(str(value).strip().strip("\"'"))
    except (TypeError, ValueError):
        return default


def _merge_unique(existing: list[str], new_value: str) -> list[str]:
    """Return ``existing`` plus ``new_value`` without duplicates or empty items.

    Insertion order is preserved so the joined result is the same on every
    run, unlike a ``set`` whose order depends on string hash randomisation.
    """
    return [item for item in dict.fromkeys([*existing, new_value]) if item]


def _unpack_descriptions(data: Mapping) -> list[str]:
    """Split the newline-joined ``description`` attribute back into a list."""
    value = data.get("description", None)
    return [] if value is None else value.split("\n")


def _unpack_source_ids(data: Mapping) -> list[str]:
    """Split the ``", "``-joined ``source_id`` attribute back into a list."""
    value = data.get("source_id", None)
    return [] if value is None else value.split(", ")