cjber committed on
Commit
13686ad
·
1 Parent(s): 8359c66

add hallucination checking cycle

Browse files
planning_ai/chains/fix_chain.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.prompts import ChatPromptTemplate
2
+
3
+ from planning_ai.chains.map_chain import SLLM
4
+
5
# Load the system prompt that asks the model to rewrite a summary which was
# flagged as hallucinated. The encoding is pinned so the prompt text is read
# identically regardless of the platform's default locale.
# NOTE(review): the path is relative to the current working directory, so
# this module presumably has to be imported from the repository root.
with open(
    "./planning_ai/chains/prompts/fix_hallucination.txt", "r", encoding="utf-8"
) as f:
    map_template = f.read()

# The prompt file contains the {summary}, {explanation} and {context}
# placeholders, which are the keys callers must supply to ``fix_chain``.
map_prompt = ChatPromptTemplate.from_messages([("system", map_template)])
# SLLM comes from map_chain; presumably a structured-output model — confirm.
fix_chain = map_prompt | SLLM

if __name__ == "__main__":
    # Manual smoke test: feed a deliberately wrong summary plus the grader's
    # explanation and print whatever the chain returns.
    test_document = """
    The Local Plan proposes a mass development north-west of Cambridge despite marked growth
    in the last twenty years or so following the previous New Settlement Study. In this period,
    the major settlement of Cambourne has been created - now over the projected 3,000 homes and
    Papworth Everard has grown beyond recognition. This in itself is a matter of concern.
    """

    result = fix_chain.invoke(
        {
            "summary": "This plan is great because they are building a nuclear power plant.",
            "explanation": "The original response does not mention a nuclear power plant, and appears to view the plan negatively",
            "context": test_document,
        }
    )
    print(result)
planning_ai/chains/hallucination_chain.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.prompts import ChatPromptTemplate
2
+ from langchain_core.pydantic_v1 import BaseModel, Field
3
+
4
+ from planning_ai.llms.llm import LLM
5
+
6
# Load the grading prompt. The encoding is pinned so the prompt text is read
# identically regardless of the platform's default locale.
# NOTE(review): the path is relative to the current working directory, so
# this module presumably has to be imported from the repository root.
with open(
    "./planning_ai/chains/prompts/hallucination.txt", "r", encoding="utf-8"
) as f:
    reduce_template = f.read()


class HallucinationChecker(BaseModel):
    """Grade the summary based upon the above criteria."""

    # The grading prompt defines 1 as "meets the criteria" and 0 as "does not".
    score: int = Field(..., description="Score for the summary")
    # Free-text justification produced alongside the score.
    explanation: str = Field(..., description="Explain your reasoning for the score")


# Ask the model to answer in the shape of HallucinationChecker.
SLLM = LLM.with_structured_output(HallucinationChecker)

# The prompt file contains the {summary} and {document} placeholders.
hallucination_prompt = ChatPromptTemplate([("human", reduce_template)])
hallucination_chain = hallucination_prompt | SLLM

if __name__ == "__main__":
    # Manual smoke test: grade a summary that contradicts the document.
    test_document = """
    The Local Plan proposes a mass development north-west of Cambridge despite marked growth
    in the last twenty years or so following the previous New Settlement Study. In this period,
    the major settlement of Cambourne has been created - now over the projected 3,000 homes and
    Papworth Everard has grown beyond recognition. This in itself is a matter of concern.
    """

    result = hallucination_chain.invoke(
        {
            "summary": "The author fully supports the plan due to the nuclear power plant.",
            "document": test_document,
        }
    )
    # Show the grade; without this the smoke test produces no output.
    print(result)
planning_ai/chains/prompts/fix_hallucination.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are tasked with summarising a response to a planning application proposed by South Cambridgeshire Council. Below, we have provided an **incorrect summary** of the response, along with an **explanation** detailing why the summary is incorrect. Your job is to generate a **correct** summary based solely on the original response provided, avoiding the errors highlighted in the explanation.
2
+
3
+ - **Incorrect Summary**:
4
+ {summary}
5
+
6
+ - **Explanation of Errors**:
7
+ {explanation}
8
+
9
+ - **Original Response**:
10
+ {context}
11
+
12
+ **Your task**: Write a concise and accurate summary of the original response, taking into account the errors highlighted in the explanation.
planning_ai/chains/prompts/hallucination.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are grading text summaries of source documents focused on faithfulness and detection of any hallucinations.
2
+
3
+ Ensure that the Assistant's Summary meets the following criteria:
4
+ (1) it does not contain information outside the scope of the source document provided
5
+ (2) the summary should be fully grounded in and based upon the source document
6
+
7
+ Score:
8
+ A score of 1 means that the Assistant's Summary meets the criteria. This is the highest (best) score.
9
+ A score of 0 means that the Assistant's Summary does not meet the criteria. This is the lowest possible score you can give.
10
+
11
+ Explain your reasoning step-by-step to ensure your reasoning and conclusion are correct.
12
+
13
+ Assistant's Summary: {summary}
14
+
15
+ Source document: {document}
planning_ai/graph.py CHANGED
@@ -19,10 +19,21 @@ from planning_ai.states import DocumentState, OverallState
19
  def create_graph():
20
  graph = StateGraph(OverallState)
21
  graph.add_node("generate_summary", generate_summary)
 
 
22
  graph.add_node("generate_final_summary", generate_final_summary)
23
 
24
  graph.add_conditional_edges(START, map_summaries, ["generate_summary"])
25
- graph.add_edge("generate_summary", "generate_final_summary")
 
 
 
 
 
 
 
 
 
26
  graph.add_edge("generate_final_summary", END)
27
 
28
  return graph.compile()
 
def create_graph():
    """Assemble and compile the summarise / check / fix / reduce graph.

    Nodes and conditional edges are registered from small tables, in the
    same order as they would be by individual calls, and the two plain
    edges are added last before compiling.
    """
    graph = StateGraph(OverallState)

    node_table = (
        ("generate_summary", generate_summary),
        ("check_hallucination", check_hallucination),
        ("fix_hallucination", fix_hallucination),
        ("generate_final_summary", generate_final_summary),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    # (source node, routing function, possible destinations)
    fan_out_table = (
        (START, map_summaries, ["generate_summary"]),
        ("generate_summary", map_hallucinations, ["check_hallucination"]),
        ("check_hallucination", map_fix_hallucinations, ["fix_hallucination"]),
        # Fixed summaries are routed back to the checker, forming the cycle.
        ("fix_hallucination", map_hallucinations, ["check_hallucination"]),
    )
    for source, router, destinations in fan_out_table:
        graph.add_conditional_edges(source, router, destinations)

    graph.add_edge("check_hallucination", "generate_final_summary")
    graph.add_edge("generate_final_summary", END)

    return graph.compile()
planning_ai/nodes/hallucination_node.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal
2
+
3
+ from langgraph.constants import Send
4
+
5
+ from planning_ai.chains.fix_chain import fix_chain
6
+ from planning_ai.chains.hallucination_chain import hallucination_chain
7
+ from planning_ai.states import DocumentState, OverallState
8
+
9
+
10
# A summary whose iteration counter exceeds this value is accepted unchecked.
MAX_HALLUCINATION_ITERATIONS = 5
# Value written to "iteration" to mark a summary accepted via the limit.
ITERATION_LIMIT_REACHED = -99


def check_hallucination(state: DocumentState):
    """Grade one summary against its source document.

    Args:
        state: Per-document state; reads "document", "summary" and
            "iteration".

    Returns:
        ``{"summaries_fixed": [...]}`` when the grader scores the summary 1,
        or when the iteration counter has exceeded
        ``MAX_HALLUCINATION_ITERATIONS`` (in which case the entry's
        "iteration" is set to ``ITERATION_LIMIT_REACHED``). Otherwise
        ``{"hallucinations": [...]}`` holding the grader response, the
        document, the summary and the iteration counter incremented by one.
    """
    if state["iteration"] > MAX_HALLUCINATION_ITERATIONS:
        # Build a copy rather than assigning into ``state`` so the caller's
        # dict (which may be shared with other state entries) is untouched.
        return {
            "summaries_fixed": [{**state, "iteration": ITERATION_LIMIT_REACHED}]
        }

    response = hallucination_chain.invoke(
        {"document": state["document"], "summary": state["summary"]}
    )
    if response.score == 1:
        return {"summaries_fixed": [state]}

    return {
        "hallucinations": [
            {
                "hallucination": response,
                "document": state["document"],
                "summary": state["summary"],
                "iteration": state["iteration"] + 1,
            }
        ]
    }
32
+
33
+
34
def map_hallucinations(state: OverallState):
    """Create one ``check_hallucination`` dispatch per stored summary.

    Returns a list of ``Send`` objects, each carrying one element of
    ``state["summaries"]`` as its payload.
    """
    # NOTE(review): "summaries" is declared with an operator.add reducer in
    # OverallState, so on later passes this looks like it re-sends entries
    # that were already checked — confirm this is intended.
    dispatches = []
    for summary in state["summaries"]:
        dispatches.append(Send("check_hallucination", summary))
    return dispatches
36
+
37
+
38
def fix_hallucination(state: DocumentState):
    """Ask the fix chain for a corrected summary of one document.

    Args:
        state: Per-document state; reads "document", "summary",
            "hallucination" (the grader response) and "iteration".

    Returns:
        ``{"summaries": [...]}`` with a single entry holding the document,
        the chain's response as the new summary, and the unchanged
        iteration counter.
    """
    corrected_summary = fix_chain.invoke(
        {
            "context": state["document"],
            "summary": state["summary"],
            # Pass the grader's explanation text only; handing over the whole
            # response object would render its repr (score included) into the
            # prompt's "Explanation of Errors" section.
            "explanation": state["hallucination"].explanation,
        }
    )
    # The new entry is built directly; the incoming state is not modified.
    return {
        "summaries": [
            {
                "document": state["document"],
                "summary": corrected_summary,
                "iteration": state["iteration"],
            }
        ]
    }
56
+
57
+
58
def map_fix_hallucinations(state: OverallState):
    """Create one ``fix_hallucination`` dispatch per failed summary.

    Args:
        state: Graph state; "hallucinations" may be absent, in which case
            nothing is dispatched.

    Returns:
        A list of ``Send`` objects, one for every entry of
        ``state["hallucinations"]`` whose grader score is not 1.
    """
    # NOTE(review): "hallucinations" is declared with an operator.add reducer
    # in OverallState, so entries from earlier passes presumably remain in
    # the list and would be dispatched again — confirm this is intended.
    return [
        Send("fix_hallucination", hallucination)
        for hallucination in state.get("hallucinations", [])
        if hallucination["hallucination"].score != 1
    ]
planning_ai/nodes/reduce_node.py CHANGED
@@ -3,11 +3,12 @@ from planning_ai.states import OverallState
3
 
4
 
5
  def generate_final_summary(state: OverallState):
6
- __import__("ipdb").set_trace()
7
- # response = reduce_chain.invoke({"context": state["summary_documents"]})
8
- return {
9
- # "final_summary": response,
10
- "summaries": state["summary_documents"],
11
- "hallucinations": state["hallucinations"],
12
- "summary": state["summary_documents"],
13
- }
 
 
3
 
4
 
def generate_final_summary(state: OverallState):
    """Reduce the verified summaries into one final summary.

    The reduction only runs when there are exactly as many entries in
    ``state["summaries_fixed"]`` as in ``state["documents"]``.

    Returns:
        ``None`` while the two lengths differ. Otherwise a dict with
        "final_summary" (the reduce chain's response) together with
        "summaries_fixed", "summaries", "hallucinations" and "documents"
        taken from the current state.
    """
    if len(state["documents"]) != len(state["summaries_fixed"]):
        # Not every document has a verified summary yet; emit no update.
        return None

    response = reduce_chain.invoke({"context": state["summaries_fixed"]})
    # NOTE(review): "summaries_fixed", "summaries" and "hallucinations" are
    # declared with operator.add reducers in OverallState, so returning the
    # current lists here presumably appends them to themselves (and appends
    # "summary_documents" onto "summaries") — confirm whether these keys
    # should be returned at all.
    return {
        "final_summary": response,
        "summaries_fixed": state["summaries_fixed"],
        "summaries": state["summary_documents"],
        "hallucinations": state["hallucinations"],
        "documents": state["documents"],
    }
planning_ai/states.py CHANGED
@@ -1,18 +1,30 @@
1
  import operator
2
  from pathlib import Path
3
- from typing import Annotated, List, TypedDict
4
 
5
  from langchain_core.documents import Document
6
 
 
 
 
7
 
8
  class OverallState(TypedDict):
9
- contents: List[str]
10
- filenames: List[str]
11
  summaries: Annotated[list, operator.add]
12
- collapsed_summaries: List[Document]
 
 
 
 
13
  final_summary: str
14
 
 
15
 
16
- class SummaryState(TypedDict):
17
- content: str
 
 
 
18
  filename: Path
 
 
 
1
  import operator
2
  from pathlib import Path
3
+ from typing import Annotated, List, Optional, TypedDict
4
 
5
  from langchain_core.documents import Document
6
 
7
+ from planning_ai.chains.hallucination_chain import HallucinationChecker
8
+ from planning_ai.chains.map_chain import BriefSummary
9
+
10
 
class OverallState(TypedDict):
    """State shared by the whole graph."""

    documents: list[str]
    # Fields annotated with operator.add use it as their merge function, so
    # node updates are concatenated onto the existing list (LangGraph
    # reducer convention) instead of replacing it.
    summaries: Annotated[list, operator.add]
    summaries_fixed: Annotated[list, operator.add]
    hallucinations: Annotated[list, operator.add]

    # Builtin generic, matching the other fields of this class.
    filenames: list[Path]
    summary_documents: Annotated[list[Document], operator.add]
    final_summary: str

    iterations: list[int]


class DocumentState(TypedDict):
    """State for a single document as it moves through check and fix."""

    document: str
    summary: BriefSummary
    # Grader response for ``summary``.
    hallucination: HallucinationChecker
    filename: Path

    # Counter incremented each time the summary fails the hallucination check.
    iteration: int