DAMHelper / prompts /prompts_manager.py
enricorampazzo's picture
now taking into account saved details when prompting and parsing the answer of the LLM
88f9105
raw
history blame
No virus
3.52 kB
from pathlib import Path
from enums import Questions
from local_storage.entities import SavedDetails, PersonalDetails, ContractorDetails, LocationDetails
from utils.date_utils import get_today_date_as_dd_mm_yyyy
class PromptsManager:
    """Load prompt templates shipped next to this module and build prompt strings.

    The templates are read once, at construction time, from files located in the
    same directory as this module: ``system_prompt.txt``, ``questions.txt`` and
    ``verification_prompt2.txt``.
    """

    def __init__(self, work_categories: "dict[str, str] | None" = None):
        """Read the prompt templates from disk.

        Args:
            work_categories: Mapping whose *values* are the category labels listed
                by :meth:`get_work_category`. May be omitted if that method is
                never called.
        """
        self.work_categories = work_categories
        base_path = Path(__file__).parent
        # Explicit UTF-8 so reading the templates does not depend on the
        # platform's locale encoding.
        with open(base_path / "system_prompt.txt", encoding="utf-8") as sysprompt_file:
            self.system_prompt: str = sysprompt_file.read()
        with open(base_path / "questions.txt", encoding="utf-8") as questions_file:
            # readlines() keeps the trailing "\n" on each entry.
            self.questions: list[str] = questions_file.readlines()
        with open(base_path / "verification_prompt2.txt", encoding="utf-8") as verification_prompt_file:
            verification_prompt = verification_prompt_file.read()
        # Substitute every "{today}" placeholder with the value returned by
        # get_today_date_as_dd_mm_yyyy().
        todays_date = get_today_date_as_dd_mm_yyyy()
        self.verification_prompt: str = verification_prompt.replace("{today}", todays_date)
        # One entry per line; verify_user_input_prompt filters these by position.
        self.verification_prompt_questions = self.verification_prompt.split("\n")

    def verify_user_input_prompt(
        self, user_prompt, exclude_questions_group: "list[SavedDetails] | None" = None
    ) -> str:
        """Build a prompt asking for answers to the verification questions from *user_prompt*.

        Args:
            user_prompt: Free text supplied by the user; interpolated into the prompt.
            exclude_questions_group: Saved-details objects whose questions should be
                left out of the prompt. ``None`` (the default) keeps every question.

        Returns:
            A fixed instruction preamble followed by the remaining lines of the
            verification prompt, joined with newlines.
        """
        prompt = (
            "\n"
            f"Using only this information \n {user_prompt} \n answer the following questions, "
            "for each question that you cannot answer just answer 'null'.\n"
            "Put each answer in a new line, keep the answer brief\n"
            "and maintain the order in which the questions are asked. Do not add any preamble:\n"
        )
        # A set gives O(1) membership tests in the filter below.
        skip_questions = set(self.get_questions(exclude_questions_group))
        # NOTE(review): this assumes each Questions member's .value is the 0-based
        # line index of that question in verification_prompt2.txt — confirm.
        questions = [
            q for idx, q in enumerate(self.verification_prompt_questions) if idx not in skip_questions
        ]
        return prompt + "\n".join(questions)

    def get_work_category(self, work_description: str) -> str:
        """Build a prompt asking which of the configured work categories fit *work_description*.

        Args:
            work_description: Description of the work, interpolated into the prompt.

        Returns:
            The prompt text, listing the values of ``self.work_categories``
            separated by ", ".

        Raises:
            ValueError: If the manager was constructed without ``work_categories``.
        """
        if self.work_categories is None:
            raise ValueError(
                "work_categories must be provided to PromptsManager to build a work-category prompt"
            )
        categories = ", ".join(self.work_categories.values())
        return (
            f"The work to do is {work_description}: choose the most accurate categories among the following: "
            f"{categories}\n"
            "Only return the categories, separated by a semicolon"
        )

    @staticmethod
    def questions_to_field_labels():
        """Return a fresh mapping from each Questions member to its form-field label."""
        return {
            Questions.FULL_NAME: "Full name", Questions.WORK_TO_DO: "Work to do", Questions.COMMUNITY: "Community",
            Questions.BUILDING: "Building name", Questions.UNIT_APT_NUMBER: "Unit/apartment number",
            Questions.OWNER_OR_TENANT: "Owner/Tenant", Questions.START_DATE: "Start date",
            Questions.END_DATE: "End date", Questions.CONTACT_NUMBER: "Your contact number",
            Questions.COMPANY_NAME: "Contractor company name", Questions.COMPANY_EMAIL: "Contracting company email",
            Questions.COMPANY_NUMBER: "Contracting company contact number", Questions.YOUR_EMAIL: "Your email"
        }

    @staticmethod
    def get_questions(details: "list[SavedDetails] | None" = None) -> list[int]:
        """Return the ``Questions`` values covered by the given saved-details objects.

        Args:
            details: Saved-details objects. Each ``PersonalDetails``,
                ``ContractorDetails`` or ``LocationDetails`` instance contributes
                the values of its group's questions. ``None`` or an empty list
                yields an empty result.

        Returns:
            The collected ``Questions`` values, in encounter order (duplicates
            are possible if the same kind of details appears more than once).
        """
        to_skip: list[int] = []
        for d in details or ():
            # Independent checks (not elif) so an object matching several
            # types contributes every matching group.
            if isinstance(d, PersonalDetails):
                to_skip.extend([Questions.FULL_NAME.value, Questions.CONTACT_NUMBER.value, Questions.YOUR_EMAIL.value])
            if isinstance(d, ContractorDetails):
                to_skip.extend([Questions.COMPANY_NAME.value, Questions.COMPANY_NUMBER.value, Questions.COMPANY_EMAIL.value])
            if isinstance(d, LocationDetails):
                to_skip.extend([Questions.OWNER_OR_TENANT.value, Questions.COMMUNITY.value, Questions.BUILDING.value, Questions.UNIT_APT_NUMBER.value])
        return to_skip