# planning_ai/eval/compare_summaries.py
# (docstrings added in commit a3397bd by cjber)
from pathlib import Path

import polars as pl
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field

from planning_ai.common.utils import Paths
from planning_ai.llms.llm import GPT4o
class SummaryEvaluator(BaseModel):
    """Structured output model for scoring a pair of candidate summaries.

    Used with ``GPT4o.with_structured_output(..., strict=True)`` so the LLM
    is constrained to return a single integer identifying the best summary.

    Attributes:
        score (int): The number of the best summary.
    """

    # The Field description is forwarded into the structured-output JSON
    # schema, giving the LLM explicit guidance on what to return.
    score: int = Field(
        ..., description="The number of the summary judged to be the best."
    )
def load_templates(template_dir=None):
    """Loads the comparison and summary prompt templates from files.

    Args:
        template_dir (str | Path | None): Directory containing ``eval.txt``
            and ``summary.txt``. Defaults to the directory this module lives
            in, so loading works regardless of the current working directory
            (the previous ``./planning_ai/eval/...`` paths only resolved when
            running from the repository root).

    Returns:
        tuple: A tuple containing the compare template and summary template
        as strings.

    Raises:
        FileNotFoundError: If either template file is missing.
    """
    base = Path(template_dir) if template_dir is not None else Path(__file__).parent
    compare_template = (base / "eval.txt").read_text(encoding="utf-8")
    summary_template = (base / "summary.txt").read_text(encoding="utf-8")
    return compare_template, summary_template
def initialize_chains(compare_template, summary_template):
    """Build the two evaluation chains from their prompt templates.

    Args:
        compare_template (str): System prompt used to compare two summaries.
        summary_template (str): System prompt used to generate a summary.

    Returns:
        tuple: ``(compare_chain, summary_chain)`` ready to ``invoke``.
    """
    # Plain text pipeline: prompt -> model -> raw string.
    summary_chain = (
        ChatPromptTemplate([("system", summary_template)]) | GPT4o | StrOutputParser()
    )
    # Structured pipeline: the model is forced to emit a SummaryEvaluator.
    structured_llm = GPT4o.with_structured_output(SummaryEvaluator, strict=True)
    compare_chain = ChatPromptTemplate([("system", compare_template)]) | structured_llm
    return compare_chain, summary_chain
def process_summaries(compare_chain, summary_chain):
    """Processes summaries by comparing and scoring them.

    Reads the staged responses, generates a fresh summary for each unique
    text, then asks the compare chain to score the existing
    ``representations_summary`` against the newly generated one.

    Args:
        compare_chain: Chain returning a ``SummaryEvaluator`` when invoked
            with ``document``, ``summary_1`` and ``summary_2``.
        summary_chain: Chain returning a summary string when invoked with
            ``content``.

    Returns:
        polars.DataFrame: A DataFrame containing the original text, both
        summaries, and the comparison scores.
    """
    # Rows with an attachment id are derived documents; keep originals only.
    original = pl.read_parquet(Paths.STAGING / "gcpt3.parquet").filter(
        pl.col("attachments_id").is_null()
    )
    summaries1 = original[["text", "representations_summary"]].unique()
    # Deduplicate on `text` before invoking the LLM: `summaries1` is unique
    # only on (text, representations_summary) pairs, so a repeated text would
    # both trigger redundant LLM calls and — because `join(on="text")` below
    # matches every pair — multiply rows in the joined result.
    summaries2 = summaries1[["text"]].unique()
    summaries2 = summaries2.with_columns(
        pl.col("text")
        .map_elements(
            lambda x: summary_chain.invoke({"content": x}), return_dtype=pl.String
        )
        .alias("summary")
    )
    summaries = summaries1.join(summaries2, on="text")
    # Score each (document, summary_1, summary_2) triple with the LLM judge.
    summaries = summaries.with_columns(
        pl.struct(["text", "representations_summary", "summary"])
        .map_elements(
            lambda x: compare_chain.invoke(
                {
                    "document": x["text"],
                    "summary_1": x["representations_summary"],
                    "summary_2": x["summary"],
                }
            ).score,
            return_dtype=pl.Int8,
        )
        .alias("score")
    )
    return summaries
def main():
    """Entry point: build the chains, score the summaries, save the result."""
    templates = load_templates()
    compare_chain, summary_chain = initialize_chains(*templates)
    result = process_summaries(compare_chain, summary_chain)
    result.write_parquet(Paths.OUT / "eval.parquet")


if __name__ == "__main__":
    main()