"""Prompt augmentation: turn a user query plus retrieved city context into chat messages for the LLM."""
import logging
from typing import Optional

from src.information_retrieval import info_retrieval as ir

logger = logging.getLogger(__name__)
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)

# Kept after the logging setup, as in the original module, so that import-time
# side effects happen in the same order.
from src.augmentation.prompts import SYSTEM_PROMPT, SUSTAINABILITY_PROMPT, USER_PROMPT

# Listing types that are rendered as attractions and as places to eat/drink.
# Every other listing type is rendered under the hotels heading.
_ATTRACTION_TYPES = ('see', 'do', 'go', 'view')
_EATERY_TYPES = ('eat', 'drink')


def generate_prompt(query: str, context: str, template: Optional[str] = None) -> list:
    """
    Function that generates the chat messages given the user query and retrieved context.

    Args:
        - query: str; the user query
        - context: str; the retrieved context, already formatted as text (see format_context)
        - template: str; system prompt to use. If empty or None, SYSTEM_PROMPT is used.

    Returns:
        - list[dict]; a system message followed by a user message, each with the keys
          "role" and "content"
    """
    system_prompt = template or SYSTEM_PROMPT
    # USER_PROMPT is filled positionally: query first, then context.
    user_prompt = USER_PROMPT.format(query, context)

    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


def _format_listings(listings) -> str:
    """Render listings grouped by category, one newline-prefixed section per non-empty category."""
    attractions, eateries, lodgings = [], [], []
    for listing in listings:
        entry = f"{listing['name']} ({listing['description']}), "
        listing_type = listing['type']
        if listing_type in _ATTRACTION_TYPES:
            attractions.append(entry)
        elif listing_type in _EATERY_TYPES:
            eateries.append(entry)
        else:
            lodgings.append(entry)

    sections = (
        ("Here are some attractions: ", attractions),
        ("Here are some places to eat/drink: ", eateries),
        ("Here are some hotels: ", lodgings),
    )
    rendered = []
    for heading, entries in sections:
        if entries:
            rendered.append("\n" + heading + "".join(entries))
    return "".join(rendered)


def format_context(context: dict) -> str:
    """
    Function that formats the retrieved context in a way that is easy for the LLM to understand.

    Args:
        - context: dict; retrieved context, mapping each city name to a dict that has the keys
          'country' and 'text', and optionally 'listings' (dicts with 'type', 'name' and
          'description') and 'sustainability' (dict with 'month' and 's-fairness')

    Returns:
        - str; one "Option N: ..." paragraph per city, or an empty string for an empty context
    """
    paragraphs = []
    for position, (city, info) in enumerate(context.items(), start=1):
        parts = [f"Option {position}: {city} is a city in {info['country']}."]

        # If we add sustainability in the end then it could get truncated because of context window
        if "sustainability" in info:
            sustainability = info['sustainability']
            if sustainability['month'] == 'No data available':
                parts.append("This city has no sustainability (or s-fairness) score available.")
            else:
                parts.append(
                    f"The sustainability (or s-fairness) score for {city} in {sustainability['month']} is "
                    f"{sustainability['s-fairness']}. \n "
                )

        parts.append(f"Here is some information about the city. {info['text']}")

        # A city without a 'listings' entry is treated the same as one with an empty list.
        parts.append(_format_listings(info.get('listings') or []))

        paragraphs.append("".join(parts) + "\n\n ")

    return "".join(paragraphs)


def augment_prompt(query: str, starting_point: str, context: dict, **params: dict) -> list:
    """
    Function that accepts the user query as input and augments the prompt with the retrieved context,
    which can be passed to the LLM.

    Args:
        - query: str; the user query
        - starting_point: str; where the trip starts, prepended to the query
        - context: dict; retrieved context. It is formatted here with format_context, otherwise the
          LLM cannot understand the nested dictionaries!
        - params: keyword arguments. Only the keyword `params` is read: a dict whose truthy
          "sustainability" entry selects SUSTAINABILITY_PROMPT as the system prompt. If it is
          missing, the default system prompt is used.

    Returns:
        - list[dict]; the chat messages built by generate_prompt
    """
    # TODO: what about the cities without s-fairness scores? i.e. they don't have seasonality data
    updated_query = f"With {starting_point} as the starting point, {query}"
    formatted_context = format_context(context)

    # Callers pass the options dict as `params=...`; tolerate it being absent instead of raising KeyError.
    options = params.get("params") or {}
    if options.get("sustainability"):
        return generate_prompt(updated_query, formatted_context, SUSTAINABILITY_PROMPT)
    return generate_prompt(updated_query, formatted_context)


def test():
    """
    Smoke test: build the prompt without and then with sustainability, and return the latter.

    Calls ir.get_context for both runs.
    """
    context_params = {
        'limit': 3,
        'reranking': 0,
        'sustainability': 0
    }
    query = "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing " \
            "in winter. "
    # augment_prompt requires a starting point; this value is only an illustrative example.
    starting_point = "Munich"

    # without sustainability (result is not returned; the call only exercises this path)
    context = ir.get_context(query, **context_params)
    without_sfairness = augment_prompt(
        query=query,
        starting_point=starting_point,
        context=context,
        params=context_params
    )

    # with sustainability
    context_params.update({'sustainability': 1})
    s_context = ir.get_context(query, **context_params)
    with_sfairness = augment_prompt(
        query=query,
        starting_point=starting_point,
        context=s_context,
        params=context_params
    )

    return with_sfairness


if __name__ == "__main__":
    prompt = test()
    print(prompt)