File size: 5,911 Bytes
8f8aff5
 
1eb5783
 
 
 
 
 
 
7d19cfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f8aff5
b117341
1eb5783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f8aff5
1eb5783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f8aff5
1eb5783
 
8f8aff5
1eb5783
 
8f8aff5
1eb5783
 
 
 
 
 
 
 
8f8aff5
1eb5783
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import json
import re
from typing import Dict, Any
from typing import List

from llama_index.core import Response

from rag.rag_pipeline import RAGPipeline
from utils.prompts import (
    structured_follow_up_prompt,
    VaccineCoverageVariables,
    StudyCharacteristics,
)

def read_study_files(file_path: str) -> Dict[str, Any]:
    """
    Read a JSON file and return the parsed JSON data.

    Args:
        file_path (str): The path to the JSON file to be read.

    Returns:
        dict: The data from the JSON file. The file is expected to hold a
            JSON object (see the example below); whatever top-level value
            the file contains is returned as parsed.

    Raises:
        FileNotFoundError: If the file is not found at the provided path.
        ValueError: If the file contents are not valid JSON. The original
            ``json.JSONDecodeError`` is chained as ``__cause__``.

    Example:
        Given a JSON file 'study_files.json' with content like:
        {
            "Vaccine Coverage": "data/vaccine_coverage_zotero_items.json",
            "Ebola Virus": "data/ebola_virus_zotero_items.json",
            "Gene Xpert": "data/gene_xpert_zotero_items.json"
        }

        Calling `read_study_files("study_files.json")` will return:
        {
            "Vaccine Coverage": "data/vaccine_coverage_zotero_items.json",
            "Ebola Virus": "data/ebola_virus_zotero_items.json",
            "Gene Xpert": "data/gene_xpert_zotero_items.json"
        }
    """
    try:
        # JSON text is UTF-8 by specification; do not depend on the
        # platform's default locale encoding.
        with open(file_path, "r", encoding="utf-8") as file:
            return json.load(file)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"The file at path {file_path} was not found.") from e
    except json.JSONDecodeError as e:
        raise ValueError(f"The file at path {file_path} does not contain valid JSON.") from e


def append_to_study_files(file_path: str, new_key: str, new_value: Any) -> None:
    """
    Add (or overwrite) a key-value entry in an existing JSON file.

    The updated document is serialized in memory *before* the file is
    reopened for writing, so a value that cannot be serialized leaves the
    file on disk untouched instead of truncated.

    Args:
        file_path (str): The path to the JSON file.
        new_key (str): The key to add to the JSON file. An existing entry
            with the same key is replaced.
        new_value (any): The value associated with the key (must be a valid
            JSON data type).

    Raises:
        FileNotFoundError: If the file is not found at the provided path.
        ValueError: If the file contents are not valid JSON. The original
            ``json.JSONDecodeError`` is chained as ``__cause__``.
        TypeError: If ``new_value`` is not JSON serializable; the file is
            not modified in that case.
        IOError: If the file cannot be accessed for another OS-level reason.

    Example:
        If the file 'study_files.json' initially contains:
        {
            "Vaccine Coverage": "data/vaccine_coverage_zotero_items.json",
            "Ebola Virus": "data/ebola_virus_zotero_items.json"
        }

        Calling `append_to_study_files("study_files.json", "Gene Xpert", "data/gene_xpert_zotero_items.json")`
        will modify the file to:
        {
            "Vaccine Coverage": "data/vaccine_coverage_zotero_items.json",
            "Ebola Virus": "data/ebola_virus_zotero_items.json",
            "Gene Xpert": "data/gene_xpert_zotero_items.json"
        }
    """
    try:
        # Read the existing data from the file
        with open(file_path, "r", encoding="utf-8") as file:
            data = json.load(file)

        # Add the new key-value pair to the dictionary
        data[new_key] = new_value

        # Serialize first: opening with "w" truncates the file immediately,
        # so any serialization error must surface before that point.
        serialized = json.dumps(data, indent=4)  # indent for pretty printing

        # Write the updated data back to the file
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(serialized)

    except FileNotFoundError as e:
        raise FileNotFoundError(f"The file at path {file_path} was not found.") from e
    except json.JSONDecodeError as e:
        raise ValueError(f"The file at path {file_path} does not contain valid JSON.") from e
    except IOError as e:
        raise IOError(f"Failed to write to the file at {file_path}.") from e


# Matches one leading list marker at the start of a line, e.g. "1. ", "2) ",
# "Q3: ", "Question 4. ", "- ", "* " or "• ". Anchoring at the start keeps
# periods inside the question itself ("Dr. Smith", "e.g. ...") intact.
_LIST_MARKER_RE = re.compile(
    r"^\s*(?:(?:Q(?:uestion)?\s*)?\d+[.):]|[-*\u2022])\s+",
    re.IGNORECASE,
)


def _clean_follow_up_question(line: str) -> str:
    """Strip a leading list marker from ``line`` and ensure it ends with "?".

    Returns an empty string when nothing but whitespace/marker remains.
    """
    text = _LIST_MARKER_RE.sub("", line, count=1).strip()
    if text and not text.endswith("?"):
        text += "?"
    return text


def generate_follow_up_questions(
    rag: RAGPipeline, response: str, query: str, study_name: str
) -> List[str]:
    """
    Generate up to three follow-up questions for a query/response pair.

    The study type and its key variables are chosen from ``study_name``, put
    in front of ``response`` as context, and sent to ``rag.query`` through
    ``structured_follow_up_prompt``. Each non-empty line of the reply is
    treated as one question: a leading list marker is removed, a trailing
    "?" is added when missing, and the result is prefixed with "✨ ".

    Args:
        rag (RAGPipeline): Pipeline whose ``query`` method is called with the
            formatted prompt; the result must expose a ``response`` attribute.
        response (str): The response to the initial query.
        query (str): The initial query.
        study_name (str): The name of the study. Names containing
            "Vaccine Coverage", "Ebola Virus" or "Gene Xpert" select a
            dedicated variable list; anything else is treated as "General".
    Returns:
        List[str]: At most three cleaned follow-up questions, in reply order.
    """

    # Determine the study type based on the study_name
    if "Vaccine Coverage" in study_name:
        study_type = "Vaccine Coverage"
        key_variables = list(VaccineCoverageVariables.__annotations__.keys())
    elif "Ebola Virus" in study_name:
        study_type = "Ebola Virus"
        key_variables = [
            "SAMPLE_SIZE",
            "PLASMA_TYPE",
            "DOSAGE",
            "FREQUENCY",
            "SIDE_EFFECTS",
            "VIRAL_LOAD_CHANGE",
            "SURVIVAL_RATE",
        ]
    elif "Gene Xpert" in study_name:
        study_type = "Gene Xpert"
        key_variables = [
            "OBJECTIVE",
            "OUTCOME_MEASURES",
            "SENSITIVITY",
            "SPECIFICITY",
            "COST_COMPARISON",
            "TURNAROUND_TIME",
        ]
    else:
        study_type = "General"
        key_variables = list(StudyCharacteristics.__annotations__.keys())

    # Add key variables to the context
    context = f"Study type: {study_type}\nKey variables to consider: {', '.join(key_variables)}\n\n{response}"

    follow_up_response = rag.query(
        structured_follow_up_prompt.format(
            context_str=context,
            query_str=query,
            response_str=response,
            study_type=study_type,
        )
    )

    # NOTE(review): presumably a llama_index Response, whose text may be
    # None — treat that as "no questions" rather than crashing. Confirm
    # against RAGPipeline.query.
    raw_text = follow_up_response.response or ""

    cleaned_questions: List[str] = []
    for line in raw_text.splitlines():
        question = _clean_follow_up_question(line)
        if not question:
            continue
        cleaned_questions.append(f"✨ {question}")
        if len(cleaned_questions) == 3:
            # Only the first three questions are ever returned.
            break
    return cleaned_questions