brestok commited on
Commit
4a1bb29
·
1 Parent(s): 198c726

Refactor message handling to support optional fields.

Browse files

Added support for handling missing or optional fields when generating prompts, responding to user inputs, and selecting treatment areas or methods. Updated utilities and enhanced prompt templates to improve flexibility and user interaction.

trauma/api/message/ai/engine.py CHANGED
@@ -50,11 +50,10 @@ async def search_entities(
50
  else:
51
  asyncio.create_task(update_entity_data_obj(entity_data, chat.id))
52
  empty_field = retrieve_empty_field_from_entity_data(entity_data)
53
-
54
- if empty_field:
55
- empty_field_instructions = pick_empty_field_instructions(empty_field)
56
  response = await generate_next_question(empty_field, empty_field_instructions, message_history_str)
57
-
58
  else:
59
  user_messages_str = prepare_user_messages_str(decoded_message, messages)
60
  possible_entity_indexes, search_request = await asyncio.gather(
@@ -63,7 +62,9 @@ async def search_entities(
63
  )
64
  final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
65
  final_entities_str = prepare_final_entities_str(final_entities)
66
- response = await generate_final_response(final_entities_str, decoded_message, message_history_str)
 
 
67
 
68
  user_message = MessageModel(chatId=chat.id, author=Author.User, text=decoded_message)
69
  assistant_message = MessageModel(chatId=chat.id, author=Author.Assistant, text=response, entities=final_entities)
@@ -95,8 +96,8 @@ async def extend_entities_with_highlights(entities: list[EntityModel], entity_da
95
  EntityModelExtended]:
96
  async def choose_closest(entity_: EntityModel) -> tuple:
97
  treatment_area, treatment_method = await asyncio.gather(
98
- choose_closest_treatment_area(entity_.treatmentAreas, entity_data['treatmentArea']),
99
- choose_closest_treatment_method(entity_.treatmentMethods, entity_data['treatmentMethod'])
100
  )
101
  return treatment_area, treatment_method
102
 
 
50
  else:
51
  asyncio.create_task(update_entity_data_obj(entity_data, chat.id))
52
  empty_field = retrieve_empty_field_from_entity_data(entity_data)
53
+ empty_field_instructions = pick_empty_field_instructions(empty_field)
54
+
55
+ if empty_field == 'age':
56
  response = await generate_next_question(empty_field, empty_field_instructions, message_history_str)
 
57
  else:
58
  user_messages_str = prepare_user_messages_str(decoded_message, messages)
59
  possible_entity_indexes, search_request = await asyncio.gather(
 
62
  )
63
  final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
64
  final_entities_str = prepare_final_entities_str(final_entities)
65
+ response = await generate_final_response(
66
+ final_entities_str, decoded_message, message_history_str, empty_field_instructions, empty_field
67
+ )
68
 
69
  user_message = MessageModel(chatId=chat.id, author=Author.User, text=decoded_message)
70
  assistant_message = MessageModel(chatId=chat.id, author=Author.Assistant, text=response, entities=final_entities)
 
96
  EntityModelExtended]:
97
  async def choose_closest(entity_: EntityModel) -> tuple:
98
  treatment_area, treatment_method = await asyncio.gather(
99
+ choose_closest_treatment_area(entity_.treatmentAreas, entity_data.get('treatmentArea')),
100
+ choose_closest_treatment_method(entity_.treatmentMethods, entity_data.get('treatmentMethod'))
101
  )
102
  return treatment_area, treatment_method
103
 
trauma/api/message/ai/openai_request.py CHANGED
@@ -50,11 +50,18 @@ async def generate_search_request(user_messages_str: str, entity_data: dict):
50
 
51
 
52
  @openai_wrapper(temperature=0.8)
53
- async def generate_final_response(final_entities: str, user_message: str, message_history_str: str):
54
- if not json.loads(final_entities)['klinieken']:
55
- prompt = TraumaPrompts.generate_empty_recommendations
 
 
 
 
 
 
 
56
  else:
57
- prompt = TraumaPrompts.generate_recommendation_decision.replace("{final_entities}", final_entities)
58
  messages = [
59
  {
60
  "role": "system",
@@ -76,7 +83,9 @@ async def convert_value_to_embeddings(value: str, dimensions: int = 1536) -> lis
76
 
77
 
78
  @openai_wrapper(is_json=True, return_='result')
79
- async def choose_closest_treatment_area(treatment_areas: list[str], treatment_area: str):
 
 
80
  messages = [
81
  {
82
  "role": "system",
@@ -89,7 +98,9 @@ async def choose_closest_treatment_area(treatment_areas: list[str], treatment_ar
89
 
90
 
91
  @openai_wrapper(is_json=True, return_='result')
92
- async def choose_closest_treatment_method(treatment_methods: list[str], treatment_method: str):
 
 
93
  messages = [
94
  {
95
  "role": "system",
 
50
 
51
 
52
  @openai_wrapper(temperature=0.8)
53
+ async def generate_final_response(
54
+ final_entities: str, user_message: str, message_history_str: str, empty_field_instructions: str, empty_field: str
55
+ ):
56
+ if empty_field_instructions:
57
+ prompt = (TraumaPrompts.generate_not_fully_recommendations
58
+ .replace("{instructions}", empty_field_instructions)
59
+ .replace("{empty_field}", empty_field))
60
+ elif json.loads(final_entities)['klinieken']:
61
+ prompt = (TraumaPrompts.generate_recommendation_decision
62
+ .replace("{final_entities}", final_entities))
63
  else:
64
+ prompt = TraumaPrompts.generate_empty_recommendations
65
  messages = [
66
  {
67
  "role": "system",
 
83
 
84
 
85
  @openai_wrapper(is_json=True, return_='result')
86
+ async def choose_closest_treatment_area(treatment_areas: list[str], treatment_area: str | None):
87
+ if not treatment_area:
88
+ return None
89
  messages = [
90
  {
91
  "role": "system",
 
98
 
99
 
100
  @openai_wrapper(is_json=True, return_='result')
101
+ async def choose_closest_treatment_method(treatment_methods: list[str], treatment_method: str | None):
102
+ if not treatment_method:
103
+ return None
104
  messages = [
105
  {
106
  "role": "system",
trauma/api/message/ai/prompts.py CHANGED
@@ -62,7 +62,7 @@ Je bent een informaticus die gegevens verzamelt over de patiënt en de gewenste
62
  {empty_field}
63
  ```
64
 
65
- - {instructies}
66
 
67
  **Gespreksgeschiedenis**:
68
  ```
@@ -72,6 +72,31 @@ Je bent een informaticus die gegevens verzamelt over de patiënt en de gewenste
72
  ## Belangrijke opmerking
73
 
74
  - De geformuleerde vraag moet beknopt zijn en een empathische toon hebben."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  generate_search_request = """## Taak
76
 
77
  Je moet een beknopte zoekopdracht genereren op basis van de berichten van de gebruiker [berichtgeschiedenis] en de verzamelde patiëntgegevens [patiëntgegevens].
 
62
  {empty_field}
63
  ```
64
 
65
+ - {instructions}
66
 
67
  **Gespreksgeschiedenis**:
68
  ```
 
72
  ## Belangrijke opmerking
73
 
74
  - De geformuleerde vraag moet beknopt zijn en een empathische toon hebben."""
75
+ generate_not_fully_recommendations = """## Taak
76
+
77
+ Vertel de gebruiker vriendelijk dat hij de aanbevelingen van de medische instelling kan bekijken op basis van de informatie die hij heeft verstrekt, en stel vervolgens de exacte vraag over het ontbrekende veld (`Ontbrekend veld`), met de uitleg dat dit de zoekresultaten zal verbeteren.
78
+
79
+ ## Context
80
+
81
+ Je bent een informatieve assistent die gegevens verzamelt over de patiënt en de gewenste kliniek. Deze gegevens worden gebruikt door het systeem om een geschikte kliniek aan te bevelen.
82
+
83
+ ## Gegevens
84
+
85
+ **Ontbrekend veld**:
86
+ ```
87
+ {empty_field}
88
+ ```
89
+
90
+ - {instructions}
91
+
92
+ **Berichtenhistorie**:
93
+ ```
94
+ {message_history}
95
+ ```
96
+
97
+ ## Belangrijke opmerking
98
+
99
+ - Verwijs naar de specifieke informatie die de gebruiker al heeft verstrekt (bijvoorbeeld leeftijd en locatie, enz.)."""
100
  generate_search_request = """## Taak
101
 
102
  Je moet een beknopte zoekopdracht genereren op basis van de berichten van de gebruiker [berichtgeschiedenis] en de verzamelde patiëntgegevens [patiëntgegevens].
trauma/api/message/utils.py CHANGED
@@ -21,7 +21,7 @@ def transform_messages_to_openai(messages: list[MessageModel]) -> list[dict]:
21
 
22
  def retrieve_empty_field_from_entity_data(entity_data: dict) -> str | None:
23
  for k, v in entity_data.items():
24
- if k not in ('treatmentArea', 'postalCode') and not v:
25
  return k
26
  return None
27
 
@@ -59,7 +59,9 @@ def pick_empty_field_instructions(empty_field: str) -> str:
59
  return "Een methode om de ziekte of stoornis te behandelen."
60
  elif empty_field == "location":
61
  return "Stad of adres waar de facility zich bevindt."
62
-
 
 
63
 
64
  def find_matching_age_group(entity: EntityModel, entity_data: dict) -> AgeGroup:
65
  age_groups = entity.ageGroups
 
21
 
22
  def retrieve_empty_field_from_entity_data(entity_data: dict) -> str | None:
23
  for k, v in entity_data.items():
24
+ if k != 'postalCode' and not v:
25
  return k
26
  return None
27
 
 
59
  return "Een methode om de ziekte of stoornis te behandelen."
60
  elif empty_field == "location":
61
  return "Stad of adres waar de facility zich bevindt."
62
+ elif empty_field == "treatmentArea":
63
+ return "Een gebied waar de facility zich bevindt."
64
+ return None
65
 
66
  def find_matching_age_group(entity: EntityModel, entity_data: dict) -> AgeGroup:
67
  age_groups = entity.ageGroups
trauma/core/wrappers.py CHANGED
@@ -53,6 +53,8 @@ def openai_wrapper(
53
  @wraps(func)
54
  async def wrapper(*args, **kwargs) -> str:
55
  messages = await func(*args, **kwargs)
 
 
56
  completion = await settings.OPENAI_CLIENT.chat.completions.create(
57
  messages=messages,
58
  temperature=temperature,
 
53
  @wraps(func)
54
  async def wrapper(*args, **kwargs) -> str:
55
  messages = await func(*args, **kwargs)
56
+ if not isinstance(messages, list):
57
+ return messages
58
  completion = await settings.OPENAI_CLIENT.chat.completions.create(
59
  messages=messages,
60
  temperature=temperature,