Spaces:
Running
Running
Refactor message handling to support optional fields.
Browse filesAdded support for handling missing or optional fields when generating prompts, responding to user inputs, and selecting treatment areas or methods. Updated utilities and enhanced prompt templates to improve flexibility and user interaction.
trauma/api/message/ai/engine.py
CHANGED
@@ -50,11 +50,10 @@ async def search_entities(
|
|
50 |
else:
|
51 |
asyncio.create_task(update_entity_data_obj(entity_data, chat.id))
|
52 |
empty_field = retrieve_empty_field_from_entity_data(entity_data)
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
response = await generate_next_question(empty_field, empty_field_instructions, message_history_str)
|
57 |
-
|
58 |
else:
|
59 |
user_messages_str = prepare_user_messages_str(decoded_message, messages)
|
60 |
possible_entity_indexes, search_request = await asyncio.gather(
|
@@ -63,7 +62,9 @@ async def search_entities(
|
|
63 |
)
|
64 |
final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
|
65 |
final_entities_str = prepare_final_entities_str(final_entities)
|
66 |
-
response = await generate_final_response(
|
|
|
|
|
67 |
|
68 |
user_message = MessageModel(chatId=chat.id, author=Author.User, text=decoded_message)
|
69 |
assistant_message = MessageModel(chatId=chat.id, author=Author.Assistant, text=response, entities=final_entities)
|
@@ -95,8 +96,8 @@ async def extend_entities_with_highlights(entities: list[EntityModel], entity_da
|
|
95 |
EntityModelExtended]:
|
96 |
async def choose_closest(entity_: EntityModel) -> tuple:
|
97 |
treatment_area, treatment_method = await asyncio.gather(
|
98 |
-
choose_closest_treatment_area(entity_.treatmentAreas, entity_data
|
99 |
-
choose_closest_treatment_method(entity_.treatmentMethods, entity_data
|
100 |
)
|
101 |
return treatment_area, treatment_method
|
102 |
|
|
|
50 |
else:
|
51 |
asyncio.create_task(update_entity_data_obj(entity_data, chat.id))
|
52 |
empty_field = retrieve_empty_field_from_entity_data(entity_data)
|
53 |
+
empty_field_instructions = pick_empty_field_instructions(empty_field)
|
54 |
+
|
55 |
+
if empty_field == 'age':
|
56 |
response = await generate_next_question(empty_field, empty_field_instructions, message_history_str)
|
|
|
57 |
else:
|
58 |
user_messages_str = prepare_user_messages_str(decoded_message, messages)
|
59 |
possible_entity_indexes, search_request = await asyncio.gather(
|
|
|
62 |
)
|
63 |
final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
|
64 |
final_entities_str = prepare_final_entities_str(final_entities)
|
65 |
+
response = await generate_final_response(
|
66 |
+
final_entities_str, decoded_message, message_history_str, empty_field_instructions, empty_field
|
67 |
+
)
|
68 |
|
69 |
user_message = MessageModel(chatId=chat.id, author=Author.User, text=decoded_message)
|
70 |
assistant_message = MessageModel(chatId=chat.id, author=Author.Assistant, text=response, entities=final_entities)
|
|
|
96 |
EntityModelExtended]:
|
97 |
async def choose_closest(entity_: EntityModel) -> tuple:
|
98 |
treatment_area, treatment_method = await asyncio.gather(
|
99 |
+
choose_closest_treatment_area(entity_.treatmentAreas, entity_data.get('treatmentArea')),
|
100 |
+
choose_closest_treatment_method(entity_.treatmentMethods, entity_data.get('treatmentMethod'))
|
101 |
)
|
102 |
return treatment_area, treatment_method
|
103 |
|
trauma/api/message/ai/openai_request.py
CHANGED
@@ -50,11 +50,18 @@ async def generate_search_request(user_messages_str: str, entity_data: dict):
|
|
50 |
|
51 |
|
52 |
@openai_wrapper(temperature=0.8)
|
53 |
-
async def generate_final_response(
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
else:
|
57 |
-
prompt = TraumaPrompts.
|
58 |
messages = [
|
59 |
{
|
60 |
"role": "system",
|
@@ -76,7 +83,9 @@ async def convert_value_to_embeddings(value: str, dimensions: int = 1536) -> lis
|
|
76 |
|
77 |
|
78 |
@openai_wrapper(is_json=True, return_='result')
|
79 |
-
async def choose_closest_treatment_area(treatment_areas: list[str], treatment_area: str):
|
|
|
|
|
80 |
messages = [
|
81 |
{
|
82 |
"role": "system",
|
@@ -89,7 +98,9 @@ async def choose_closest_treatment_area(treatment_areas: list[str], treatment_ar
|
|
89 |
|
90 |
|
91 |
@openai_wrapper(is_json=True, return_='result')
|
92 |
-
async def choose_closest_treatment_method(treatment_methods: list[str], treatment_method: str):
|
|
|
|
|
93 |
messages = [
|
94 |
{
|
95 |
"role": "system",
|
|
|
50 |
|
51 |
|
52 |
@openai_wrapper(temperature=0.8)
|
53 |
+
async def generate_final_response(
|
54 |
+
final_entities: str, user_message: str, message_history_str: str, empty_field_instructions: str, empty_field: str
|
55 |
+
):
|
56 |
+
if empty_field_instructions:
|
57 |
+
prompt = (TraumaPrompts.generate_not_fully_recommendations
|
58 |
+
.replace("{instructions}", empty_field_instructions)
|
59 |
+
.replace("{empty_field}", empty_field))
|
60 |
+
elif json.loads(final_entities)['klinieken']:
|
61 |
+
prompt = (TraumaPrompts.generate_recommendation_decision
|
62 |
+
.replace("{final_entities}", final_entities))
|
63 |
else:
|
64 |
+
prompt = TraumaPrompts.generate_empty_recommendations
|
65 |
messages = [
|
66 |
{
|
67 |
"role": "system",
|
|
|
83 |
|
84 |
|
85 |
@openai_wrapper(is_json=True, return_='result')
|
86 |
+
async def choose_closest_treatment_area(treatment_areas: list[str], treatment_area: str | None):
|
87 |
+
if not treatment_area:
|
88 |
+
return None
|
89 |
messages = [
|
90 |
{
|
91 |
"role": "system",
|
|
|
98 |
|
99 |
|
100 |
@openai_wrapper(is_json=True, return_='result')
|
101 |
+
async def choose_closest_treatment_method(treatment_methods: list[str], treatment_method: str | None):
|
102 |
+
if not treatment_method:
|
103 |
+
return None
|
104 |
messages = [
|
105 |
{
|
106 |
"role": "system",
|
trauma/api/message/ai/prompts.py
CHANGED
@@ -62,7 +62,7 @@ Je bent een informaticus die gegevens verzamelt over de patiënt en de gewenste
|
|
62 |
{empty_field}
|
63 |
```
|
64 |
|
65 |
-
- {
|
66 |
|
67 |
**Gespreksgeschiedenis**:
|
68 |
```
|
@@ -72,6 +72,31 @@ Je bent een informaticus die gegevens verzamelt over de patiënt en de gewenste
|
|
72 |
## Belangrijke opmerking
|
73 |
|
74 |
- De geformuleerde vraag moet beknopt zijn en een empathische toon hebben."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
generate_search_request = """## Taak
|
76 |
|
77 |
Je moet een beknopte zoekopdracht genereren op basis van de berichten van de gebruiker [berichtgeschiedenis] en de verzamelde patiëntgegevens [patiëntgegevens].
|
|
|
62 |
{empty_field}
|
63 |
```
|
64 |
|
65 |
+
- {instructions}
|
66 |
|
67 |
**Gespreksgeschiedenis**:
|
68 |
```
|
|
|
72 |
## Belangrijke opmerking
|
73 |
|
74 |
- De geformuleerde vraag moet beknopt zijn en een empathische toon hebben."""
|
75 |
+
generate_not_fully_recommendations = """## Taak
|
76 |
+
|
77 |
+
Vertel de gebruiker vriendelijk dat hij de aanbevelingen van de medische instelling kan bekijken op basis van de informatie die hij heeft verstrekt, en stel vervolgens de exacte vraag over het ontbrekende veld (`Ontbrekend veld`), met de uitleg dat dit de zoekresultaten zal verbeteren.
|
78 |
+
|
79 |
+
## Context
|
80 |
+
|
81 |
+
Je bent een informatieve assistent die gegevens verzamelt over de patiënt en de gewenste kliniek. Deze gegevens worden gebruikt door het systeem om een geschikte kliniek aan te bevelen.
|
82 |
+
|
83 |
+
## Gegevens
|
84 |
+
|
85 |
+
**Ontbrekend veld**:
|
86 |
+
```
|
87 |
+
{empty_field}
|
88 |
+
```
|
89 |
+
|
90 |
+
- {instructions}
|
91 |
+
|
92 |
+
**Berichtenhistorie**:
|
93 |
+
```
|
94 |
+
{message_history}
|
95 |
+
```
|
96 |
+
|
97 |
+
## Belangrijke opmerking
|
98 |
+
|
99 |
+
- Verwijs naar de specifieke informatie die de gebruiker al heeft verstrekt (bijvoorbeeld leeftijd en locatie, enz.)."""
|
100 |
generate_search_request = """## Taak
|
101 |
|
102 |
Je moet een beknopte zoekopdracht genereren op basis van de berichten van de gebruiker [berichtgeschiedenis] en de verzamelde patiëntgegevens [patiëntgegevens].
|
trauma/api/message/utils.py
CHANGED
@@ -21,7 +21,7 @@ def transform_messages_to_openai(messages: list[MessageModel]) -> list[dict]:
|
|
21 |
|
22 |
def retrieve_empty_field_from_entity_data(entity_data: dict) -> str | None:
|
23 |
for k, v in entity_data.items():
|
24 |
-
if k
|
25 |
return k
|
26 |
return None
|
27 |
|
@@ -59,7 +59,9 @@ def pick_empty_field_instructions(empty_field: str) -> str:
|
|
59 |
return "Een methode om de ziekte of stoornis te behandelen."
|
60 |
elif empty_field == "location":
|
61 |
return "Stad of adres waar de facility zich bevindt."
|
62 |
-
|
|
|
|
|
63 |
|
64 |
def find_matching_age_group(entity: EntityModel, entity_data: dict) -> AgeGroup:
|
65 |
age_groups = entity.ageGroups
|
|
|
21 |
|
22 |
def retrieve_empty_field_from_entity_data(entity_data: dict) -> str | None:
|
23 |
for k, v in entity_data.items():
|
24 |
+
if k != 'postalCode' and not v:
|
25 |
return k
|
26 |
return None
|
27 |
|
|
|
59 |
return "Een methode om de ziekte of stoornis te behandelen."
|
60 |
elif empty_field == "location":
|
61 |
return "Stad of adres waar de facility zich bevindt."
|
62 |
+
elif empty_field == "treatmentArea":
|
63 |
+
return "Een gebied waar de facility zich bevindt."
|
64 |
+
return None
|
65 |
|
66 |
def find_matching_age_group(entity: EntityModel, entity_data: dict) -> AgeGroup:
|
67 |
age_groups = entity.ageGroups
|
trauma/core/wrappers.py
CHANGED
@@ -53,6 +53,8 @@ def openai_wrapper(
|
|
53 |
@wraps(func)
|
54 |
async def wrapper(*args, **kwargs) -> str:
|
55 |
messages = await func(*args, **kwargs)
|
|
|
|
|
56 |
completion = await settings.OPENAI_CLIENT.chat.completions.create(
|
57 |
messages=messages,
|
58 |
temperature=temperature,
|
|
|
53 |
@wraps(func)
|
54 |
async def wrapper(*args, **kwargs) -> str:
|
55 |
messages = await func(*args, **kwargs)
|
56 |
+
if not isinstance(messages, list):
|
57 |
+
return messages
|
58 |
completion = await settings.OPENAI_CLIENT.chat.completions.create(
|
59 |
messages=messages,
|
60 |
temperature=temperature,
|