feat: add eval script
Former-commit-id: 9ca03cf7fd6b5cebbb24e023ab238eafd11ed76e [formerly e2d22f446a0f4f0a2817dadb84684eec596c5bc6]
Former-commit-id: 8362e9b7fd02c0f7ce0d4f350e7b7758c97952ef
planning_ai/eval/compare_summaries.py
CHANGED
@@ -6,53 +6,69 @@ from pydantic import BaseModel, Field
 from planning_ai.common.utils import Paths
 from planning_ai.llms.llm import GPT4o
 
-with open("./planning_ai/eval/eval.txt", "r") as f:
-    compare_template = f.read()
-
-with open("./planning_ai/eval/summary.txt", "r") as f:
-    summary_template = f.read()
-
 
 class SummaryEvaluator(BaseModel):
     score: int = Field(..., description="The number of the best summary.")
 
 
-summary_prompt = ChatPromptTemplate([("system", summary_template)])
-summary_chain = summary_prompt | GPT4o | StrOutputParser()
-
-original = pl.read_parquet(Paths.STAGING / "gcpt3.parquet").filter(
-    pl.col("attachments_id").is_null()
-)
-summaries1 = original[["text", "representations_summary"]].unique().head(20)
-
+def load_templates():
+    with open("./planning_ai/eval/eval.txt", "r") as f:
+        compare_template = f.read()
+    with open("./planning_ai/eval/summary.txt", "r") as f:
+        summary_template = f.read()
+    return compare_template, summary_template
+
+
+def initialize_chains(compare_template, summary_template):
+    SLLM = GPT4o.with_structured_output(SummaryEvaluator, strict=True)
+    compare_prompt = ChatPromptTemplate([("system", compare_template)])
+    compare_chain = compare_prompt | SLLM
+
+    summary_prompt = ChatPromptTemplate([("system", summary_template)])
+    summary_chain = summary_prompt | GPT4o | StrOutputParser()
+
+    return compare_chain, summary_chain
+
+
+def process_summaries(compare_chain, summary_chain):
+    original = pl.read_parquet(Paths.STAGING / "gcpt3.parquet").filter(
+        pl.col("attachments_id").is_null()
+    )
+    summaries1 = original[["text", "representations_summary"]].unique()
+
+    summaries2 = summaries1[["text"]]
+    summaries2 = summaries2.with_columns(
+        pl.col("text")
+        .map_elements(
+            lambda x: summary_chain.invoke({"content": x}), return_dtype=pl.String
+        )
+        .alias("summary")
+    )
+
+    summaries = summaries1.join(summaries2, on="text")
+    summaries = summaries.with_columns(
+        pl.struct(["text", "representations_summary", "summary"])
+        .map_elements(
+            lambda x: compare_chain.invoke(
+                {
+                    "document": x["text"],
+                    "summary_1": x["representations_summary"],
+                    "summary_2": x["summary"],
+                }
+            ).score,
+            return_dtype=pl.Int8,
+        )
+        .alias("score")
+    )
+    return summaries
+
+
+def main():
+    compare_template, summary_template = load_templates()
+    compare_chain, summary_chain = initialize_chains(compare_template, summary_template)
+    summaries = process_summaries(compare_chain, summary_chain)
+    summaries.write_parquet(Paths.OUT / "eval.parquet")
+
+
+if __name__ == "__main__":
+    main()
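
Both prompt files are assumed to expose the variables used at invoke time: summary.txt should reference {content}, and eval.txt should reference {document}, {summary_1}, and {summary_2}. Because summary_1 carries the existing representations_summary and summary_2 the freshly generated summary, a score of 1 presumably means the stored summary was preferred and 2 the new one. A minimal sketch for inspecting the output written by main(), assuming the same Paths configuration (not part of this commit):

import polars as pl

from planning_ai.common.utils import Paths

# Tally how often each candidate summary was judged best; the score values
# correspond to the summary_1/summary_2 slots passed to compare_chain.
scores = pl.read_parquet(Paths.OUT / "eval.parquet")
print(scores["score"].value_counts())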