Spaces:
Sleeping
Sleeping
File size: 3,229 Bytes
d5b7cf9 a3397bd d5b7cf9 48ebd8a a3397bd 48ebd8a d5b7cf9 48ebd8a a3397bd 48ebd8a d5b7cf9 48ebd8a d5b7cf9 48ebd8a d5b7cf9 48ebd8a a3397bd 48ebd8a d5b7cf9 48ebd8a d5b7cf9 48ebd8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import polars as pl
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
from planning_ai.common.utils import Paths
from planning_ai.llms.llm import GPT4o
class SummaryEvaluator(BaseModel):
    """Model for evaluating summaries.

    Attributes:
        score (int): The number of the best summary.
    """

    # NOTE(review): this model is handed to ``GPT4o.with_structured_output``
    # in ``initialize_chains``. LangChain presumably forwards the class
    # docstring to the LLM as the schema/tool description, so rewording the
    # docstring may change the judge's behaviour — confirm before editing it.
    #
    # Required field (``Field(...)``). process_summaries reads it via
    # ``.score`` after comparing ``summary_1`` (the existing
    # representations_summary) with ``summary_2`` (the freshly generated
    # summary), and stores it in an Int8 ``score`` column.
    score: int = Field(...)
def load_templates(template_dir="./planning_ai/eval"):
    """Loads the comparison and summary templates from files.

    Args:
        template_dir (str | os.PathLike): Directory containing ``eval.txt``
            (the comparison template) and ``summary.txt`` (the summary
            template). Defaults to ``./planning_ai/eval``, which is resolved
            against the current working directory, as before.

    Returns:
        tuple: ``(compare_template, summary_template)`` as strings.
    """
    # Decode explicitly as UTF-8 so the templates read identically on every
    # platform rather than depending on the locale's default encoding.
    with open(f"{template_dir}/eval.txt", "r", encoding="utf-8") as f:
        compare_template = f.read()
    with open(f"{template_dir}/summary.txt", "r", encoding="utf-8") as f:
        summary_template = f.read()
    return compare_template, summary_template
def initialize_chains(compare_template, summary_template):
    """Build the two LangChain pipelines used by the evaluation.

    Args:
        compare_template (str): System-prompt template used to judge two
            summaries against a document.
        summary_template (str): System-prompt template used to summarise a
            document.

    Returns:
        tuple: ``(compare_chain, summary_chain)``. The compare chain ends in
        ``GPT4o`` with structured output bound to ``SummaryEvaluator``; the
        summary chain ends in ``StrOutputParser`` and yields a string.
    """

    def as_system_prompt(template):
        # Each prompt is a single system message built from its template.
        return ChatPromptTemplate([("system", template)])

    structured_llm = GPT4o.with_structured_output(SummaryEvaluator, strict=True)
    compare_chain = as_system_prompt(compare_template) | structured_llm
    summary_chain = as_system_prompt(summary_template) | GPT4o | StrOutputParser()
    return compare_chain, summary_chain
def process_summaries(compare_chain, summary_chain):
    """Processes summaries by comparing and scoring them.

    Reads ``gcpt3.parquet`` from ``Paths.STAGING`` and keeps rows whose
    ``attachments_id`` is null. It then generates one new summary per distinct
    ``text`` with ``summary_chain`` and asks ``compare_chain`` to judge the
    existing ``representations_summary`` against the generated one.

    Args:
        compare_chain: Runnable invoked with ``document``, ``summary_1`` and
            ``summary_2``; its result must expose an integer ``score``.
        summary_chain: Runnable invoked with ``content``; returns a string.

    Returns:
        polars.DataFrame: Columns ``text``, ``representations_summary``,
        ``summary`` and ``score``, with one row per distinct
        (text, representations_summary) pair. Polars does not match null join
        keys by default, so rows with a null ``text`` are dropped.
    """
    original = pl.read_parquet(Paths.STAGING / "gcpt3.parquet").filter(
        pl.col("attachments_id").is_null()
    )
    pairs = original[["text", "representations_summary"]].unique()
    # De-duplicate texts before calling the LLM. A text paired with several
    # representations_summary values must be summarised once. Otherwise the
    # join below would multiply rows (k pairs x k generated rows = k**2).
    generated = _generate_summaries(pairs[["text"]].unique(), summary_chain)
    summaries = pairs.join(generated, on="text")
    return _score_summaries(summaries, compare_chain)


def _generate_summaries(texts, summary_chain):
    """Return ``texts`` with a ``summary`` column produced by ``summary_chain``."""
    return texts.with_columns(
        pl.col("text")
        .map_elements(
            lambda text: summary_chain.invoke({"content": text}),
            return_dtype=pl.String,
        )
        .alias("summary")
    )


def _score_summaries(summaries, compare_chain):
    """Return ``summaries`` with a ``score`` column from ``compare_chain``."""
    return summaries.with_columns(
        pl.struct(["text", "representations_summary", "summary"])
        .map_elements(
            lambda row: compare_chain.invoke(
                {
                    "document": row["text"],
                    "summary_1": row["representations_summary"],
                    "summary_2": row["summary"],
                }
            ).score,
            return_dtype=pl.Int8,
        )
        .alias("score")
    )
def main():
    """Run the summary evaluation end to end and save it as ``eval.parquet``.

    Loads the prompt templates, builds the chains, scores the summaries and
    writes the resulting table to ``Paths.OUT / "eval.parquet"``.
    """
    templates = load_templates()
    chains = initialize_chains(*templates)
    scored = process_summaries(*chains)
    scored.write_parquet(Paths.OUT / "eval.parquet")


if __name__ == "__main__":
    main()
|