Phoenix21 commited on
Commit
ca4603b
·
verified ·
1 Parent(s): f7f11ae

Update chain_problems.py

Browse files
Files changed (1) hide show
  1. chain_problems.py +28 -8
chain_problems.py CHANGED
@@ -1,4 +1,3 @@
1
- # chain_problems.py
2
  import json
3
  import logging
4
  from typing import Dict
@@ -7,6 +6,7 @@ from models import chat_model
7
 
8
  logger = logging.getLogger(__name__)
9
 
 
10
  problem_prompt_template = PromptTemplate(
11
  input_variables=["responses", "internal_report"],
12
  template=(
@@ -15,30 +15,50 @@ problem_prompt_template = PromptTemplate(
15
  "You also have an internal analysis report:\n"
16
  "{internal_report}\n\n"
17
  "From these inputs, determine a 'problem severity percentage' for the user in the following areas: "
18
- "sleep, exercise, stress, and diet. "
19
- "Return your answer in JSON format with keys: sleep_problem, exercise_problem, stress_problem, diet_problem.\n"
 
20
  "Ensure severity percentages are numbers from 0 to 100.\n\n"
21
  "JSON Output:"
22
  )
23
  )
 
24
  problem_chain = LLMChain(llm=chat_model, prompt=problem_prompt_template)
25
 
26
  def analyze_problems_with_chain(responses: Dict[str, str], internal_report: str) -> Dict[str, float]:
27
  responses_str = "\n".join(f"{q}: {a}" for q, a in responses.items())
28
  raw_text = problem_chain.run(responses=responses_str, internal_report=internal_report)
29
  try:
 
30
  start_idx = raw_text.find('{')
31
  end_idx = raw_text.rfind('}') + 1
32
  json_str = raw_text[start_idx:end_idx]
33
  problems = json.loads(json_str)
34
- for key in ["sleep_problem", "exercise_problem", "stress_problem", "diet_problem"]:
 
 
 
 
 
 
 
 
 
 
 
35
  problems.setdefault(key, 0.0)
 
36
  return {k: float(v) for k, v in problems.items()}
37
  except Exception as e:
38
  logger.error(f"Error parsing problem percentages from LLM: {e}")
 
39
  return {
40
- "sleep_problem": 0.0,
41
- "exercise_problem": 0.0,
42
- "stress_problem": 0.0,
43
- "diet_problem": 0.0
 
 
 
 
44
  }
 
 
1
  import json
2
  import logging
3
  from typing import Dict
 
6
 
7
  logger = logging.getLogger(__name__)
8
 
9
+ # Updated prompt template to include eight new themes
10
  problem_prompt_template = PromptTemplate(
11
  input_variables=["responses", "internal_report"],
12
  template=(
 
15
  "You also have an internal analysis report:\n"
16
  "{internal_report}\n\n"
17
  "From these inputs, determine a 'problem severity percentage' for the user in the following areas: "
18
+ "stress_management, low_therapy, balanced_weight, restless_night, lack_of_motivation, gut_health, anxiety, burnout. "
19
+ "Return your answer in JSON format with keys: stress_management, low_therapy, balanced_weight, restless_night, "
20
+ "lack_of_motivation, gut_health, anxiety, burnout.\n"
21
  "Ensure severity percentages are numbers from 0 to 100.\n\n"
22
  "JSON Output:"
23
  )
24
  )
25
+
26
  problem_chain = LLMChain(llm=chat_model, prompt=problem_prompt_template)
27
 
28
  def analyze_problems_with_chain(responses: Dict[str, str], internal_report: str) -> Dict[str, float]:
29
  responses_str = "\n".join(f"{q}: {a}" for q, a in responses.items())
30
  raw_text = problem_chain.run(responses=responses_str, internal_report=internal_report)
31
  try:
32
+ # Extract JSON from the LLM output
33
  start_idx = raw_text.find('{')
34
  end_idx = raw_text.rfind('}') + 1
35
  json_str = raw_text[start_idx:end_idx]
36
  problems = json.loads(json_str)
37
+
38
+ # Ensure all eight keys are present with default values
39
+ for key in [
40
+ "stress_management",
41
+ "low_therapy",
42
+ "balanced_weight",
43
+ "restless_night",
44
+ "lack_of_motivation",
45
+ "gut_health",
46
+ "anxiety",
47
+ "burnout"
48
+ ]:
49
  problems.setdefault(key, 0.0)
50
+
51
  return {k: float(v) for k, v in problems.items()}
52
  except Exception as e:
53
  logger.error(f"Error parsing problem percentages from LLM: {e}")
54
+ # Return default values for all eight themes in case of an error
55
  return {
56
+ "stress_management": 0.0,
57
+ "low_therapy": 0.0,
58
+ "balanced_weight": 0.0,
59
+ "restless_night": 0.0,
60
+ "lack_of_motivation": 0.0,
61
+ "gut_health": 0.0,
62
+ "anxiety": 0.0,
63
+ "burnout": 0.0
64
  }