cjber committed
Commit 48ebd8a · 1 Parent(s): ea33740

feat: add eval script


Former-commit-id: 9ca03cf7fd6b5cebbb24e023ab238eafd11ed76e [formerly e2d22f446a0f4f0a2817dadb84684eec596c5bc6]
Former-commit-id: 8362e9b7fd02c0f7ce0d4f350e7b7758c97952ef

Files changed (1): planning_ai/eval/compare_summaries.py (+54, -38)
planning_ai/eval/compare_summaries.py CHANGED
@@ -6,53 +6,69 @@ from pydantic import BaseModel, Field
 from planning_ai.common.utils import Paths
 from planning_ai.llms.llm import GPT4o
 
-with open("./planning_ai/eval/eval.txt", "r") as f:
-    compare_template = f.read()
-
-with open("./planning_ai/eval/summary.txt", "r") as f:
-    summary_template = f.read()
-
 
 class SummaryEvaluator(BaseModel):
     score: int = Field(..., description="The number of the best summary.")
 
 
-SLLM = GPT4o.with_structured_output(SummaryEvaluator, strict=True)
-
-compare_prompt = ChatPromptTemplate([("system", compare_template)])
-compare_chain = compare_prompt | SLLM
-
-summary_prompt = ChatPromptTemplate([("system", summary_template)])
-summary_chain = summary_prompt | GPT4o | StrOutputParser()
-
-
-original = pl.read_parquet(Paths.STAGING / "gcpt3.parquet").filter(
-    pl.col("attachments_id").is_null()
-)
-summaries1 = original[["text", "representations_summary"]].unique().head(20)
-
-summaries2 = summaries1[["text"]]
-summaries2 = summaries2.with_columns(
-    pl.col("text")
-    .map_elements(
-        lambda x: summary_chain.invoke({"content": x}), return_dtype=pl.String
-    )
-    .alias("summary")
-)
-
-summaries = summaries1.join(summaries2, on="text")
-summaries = summaries.with_columns(
-    pl.struct(["text", "representations_summary", "summary"])
-    .map_elements(
-        lambda x: compare_chain.invoke(
-            {
-                "document": x["text"],
-                "summary_1": x["representations_summary"],
-                "summary_2": x["summary"],
-            }
-        ).score,
-        return_dtype=pl.Int8,
-    )
-    .alias("score")
-)
-summaries["score"].value_counts()
+def load_templates():
+    with open("./planning_ai/eval/eval.txt", "r") as f:
+        compare_template = f.read()
+    with open("./planning_ai/eval/summary.txt", "r") as f:
+        summary_template = f.read()
+    return compare_template, summary_template
+
+
+def initialize_chains(compare_template, summary_template):
+    SLLM = GPT4o.with_structured_output(SummaryEvaluator, strict=True)
+    compare_prompt = ChatPromptTemplate([("system", compare_template)])
+    compare_chain = compare_prompt | SLLM
+
+    summary_prompt = ChatPromptTemplate([("system", summary_template)])
+    summary_chain = summary_prompt | GPT4o | StrOutputParser()
+
+    return compare_chain, summary_chain
+
+
+def process_summaries(compare_chain, summary_chain):
+    original = pl.read_parquet(Paths.STAGING / "gcpt3.parquet").filter(
+        pl.col("attachments_id").is_null()
+    )
+    summaries1 = original[["text", "representations_summary"]].unique()
+
+    summaries2 = summaries1[["text"]]
+    summaries2 = summaries2.with_columns(
+        pl.col("text")
+        .map_elements(
+            lambda x: summary_chain.invoke({"content": x}), return_dtype=pl.String
+        )
+        .alias("summary")
+    )
+
+    summaries = summaries1.join(summaries2, on="text")
+    summaries = summaries.with_columns(
+        pl.struct(["text", "representations_summary", "summary"])
+        .map_elements(
+            lambda x: compare_chain.invoke(
+                {
+                    "document": x["text"],
+                    "summary_1": x["representations_summary"],
+                    "summary_2": x["summary"],
+                }
+            ).score,
+            return_dtype=pl.Int8,
+        )
+        .alias("score")
+    )
+    return summaries
+
+
+def main():
+    compare_template, summary_template = load_templates()
+    compare_chain, summary_chain = initialize_chains(compare_template, summary_template)
+    summaries = process_summaries(compare_chain, summary_chain)
+    summaries.write_parquet(Paths.OUT / "eval.parquet")
+
+
+if __name__ == "__main__":
+    main()
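
For context, the comparison chain added here parses the model's response into a SummaryEvaluator, so its score field records which of the two candidate summaries the model judged best. A minimal usage sketch on a single document follows; it assumes the planning_ai package is importable, GPT4o is configured with valid credentials, and the prompt templates use the document/summary_1/summary_2 placeholders seen in the diff. The toy input strings are purely illustrative (real inputs come from the gcpt3.parquet staging file):

    from planning_ai.eval.compare_summaries import load_templates, initialize_chains

    compare_template, summary_template = load_templates()
    compare_chain, summary_chain = initialize_chains(compare_template, summary_template)

    # Hypothetical inputs standing in for a document and two candidate summaries.
    document = "The council proposes 200 new homes near the station."
    summary_1 = "200 new homes are proposed near the station."
    summary_2 = "A road closure is planned."

    # compare_chain returns a SummaryEvaluator instance, so .score is the
    # number of the summary the model prefers.
    result = compare_chain.invoke(
        {"document": document, "summary_1": summary_1, "summary_2": summary_2}
    )
    print(result.score)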