File size: 2,603 Bytes
8f8aff5
 
1eb5783
 
 
 
 
 
 
8f8aff5
b117341
1eb5783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f8aff5
1eb5783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f8aff5
1eb5783
 
8f8aff5
1eb5783
 
8f8aff5
1eb5783
 
 
 
 
 
 
 
8f8aff5
1eb5783
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import re
from typing import Dict, Any
from typing import List

from llama_index.core import Response

from rag.rag_pipeline import RAGPipeline
from utils.prompts import (
    structured_follow_up_prompt,
    VaccineCoverageVariables,
    StudyCharacteristics,
)


# Leading list marker at the very start of a line: "1. ", "2) ", "- ", "* ", "• ".
# Anchored so that a ". " occurring later inside the question is never touched.
_LEADING_LIST_MARKER = re.compile(r"^\s*(?:\d+\s*[.)](?:\s+|$)|[-*\u2022]\s+)")

# Upper bound on the number of follow-up questions returned to the caller.
_MAX_FOLLOW_UP_QUESTIONS = 3


def _clean_question(line: str) -> str:
    """Strip a leading list marker from *line* and make sure it ends with "?".

    Returns an empty string when nothing is left after stripping.
    """
    text = _LEADING_LIST_MARKER.sub("", line, count=1).strip()
    if text and not text.endswith("?"):
        text += "?"
    return text


def generate_follow_up_questions(
    rag: RAGPipeline, response: str, query: str, study_name: str
) -> List[str]:
    """
    Ask the RAG pipeline for follow-up questions to a previous answer.

    The study type is derived from substrings of ``study_name``
    ("Vaccine Coverage", "Ebola Virus", "Gene Xpert", otherwise "General") and,
    together with a list of key variables for that type, is prepended to
    ``response`` to form the context passed to ``structured_follow_up_prompt``.
    The text returned by ``rag.query`` is split into lines; each non-empty line
    has a leading list marker removed, gets a trailing "?" if missing, and is
    prefixed with "✨ ".

    Args:
        rag (RAGPipeline): Pipeline whose ``query`` method is called with the
            formatted prompt; the result's ``response`` attribute is read.
        response (str): The response to the initial query.
        query (str): The initial query.
        study_name (str): The name of the study, used to pick the study type.
    Returns:
        List[str]: At most three cleaned follow-up questions, in the order the
        pipeline produced them. Empty if the pipeline returned no text.
    """

    # Determine the study type based on the study_name
    if "Vaccine Coverage" in study_name:
        study_type = "Vaccine Coverage"
        key_variables = list(VaccineCoverageVariables.__annotations__.keys())
    elif "Ebola Virus" in study_name:
        study_type = "Ebola Virus"
        key_variables = [
            "SAMPLE_SIZE",
            "PLASMA_TYPE",
            "DOSAGE",
            "FREQUENCY",
            "SIDE_EFFECTS",
            "VIRAL_LOAD_CHANGE",
            "SURVIVAL_RATE",
        ]
    elif "Gene Xpert" in study_name:
        study_type = "Gene Xpert"
        key_variables = [
            "OBJECTIVE",
            "OUTCOME_MEASURES",
            "SENSITIVITY",
            "SPECIFICITY",
            "COST_COMPARISON",
            "TURNAROUND_TIME",
        ]
    else:
        study_type = "General"
        key_variables = list(StudyCharacteristics.__annotations__.keys())

    # Add key variables to the context
    context = f"Study type: {study_type}\nKey variables to consider: {', '.join(key_variables)}\n\n{response}"

    follow_up_response = rag.query(
        structured_follow_up_prompt.format(
            context_str=context,
            query_str=query,
            response_str=response,
            study_type=study_type,
        )
    )

    # The response text may be None when the pipeline produced no answer;
    # treat that the same as an empty answer instead of raising.
    raw_text = follow_up_response.response or ""

    cleaned_questions: List[str] = []
    for line in raw_text.strip().splitlines():
        cleaned_q = _clean_question(line)
        if not cleaned_q:
            continue
        cleaned_questions.append(f"✨ {cleaned_q}")
        # Stop early: only the first few questions are ever returned.
        if len(cleaned_questions) == _MAX_FOLLOW_UP_QUESTIONS:
            break
    return cleaned_questions