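"""Training-adherence (TA) scoring utilities.

Builds prompts from LangChain conversation memory, queries the TA model
serving endpoint on Databricks, post-processes the generated answers, and
pushes scoring comparisons to MongoDB.
"""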
import os
import re
import requests
import string
import streamlit as st
from streamlit.logger import get_logger
from app_config import ENDPOINT_NAMES
from models.ta_models.config import NAME2PROMPT, NAME2PROMPT_EXPL, START_INST, END_INST, QUESTIONDEFAULTS, TA_OPTIONS, NAME2QUESTION
import pandas as pd
from langchain_core.messages import AIMessage, HumanMessage
from models.ta_models.ta_prompt_utils import load_context
from utils.mongo_utils import new_convo_scoring_comparison

logger = get_logger(__name__)
TA_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["training_adherence"]['name'])
HEADERS = {
    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
    "Content-Type": "application/json",
}

def memory2df(memory, conversation_id="convo1234"):
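    """Convert LangChain memory messages into a one-row-per-message DataFrame.

    AIMessage entries are labeled "texter" and HumanMessage entries "helper";
    messages of any other type are skipped.
    """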
    rows = []
    for i, msg in enumerate(memory.buffer_as_messages):
        actor_role = "texter" if isinstance(msg, AIMessage) else "helper" if isinstance(msg, HumanMessage) else None
        if actor_role:
            convo_part = msg.response_metadata.get("phase", None)
            row = {
                "conversation_id": conversation_id,
                "message_number": i + 1,
                "actor_role": actor_role,
                "message": msg.content,
                "convo_part": convo_part,
            }
            rows.append(row)

    return pd.DataFrame(rows)

def get_default(question, make_explanation=False):
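    """Return the default answer for `question` (with explanation if requested) when no context is available."""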
    return QUESTIONDEFAULTS[question][make_explanation]

def get_context(memory, question, make_explanation=False, **kwargs):
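    """Build the TA prompt for `question` from the conversation memory.

    Returns an empty string when load_context finds no relevant messages,
    which callers treat as a signal to fall back to the question's default
    answer.
    """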
    df = memory2df(memory, **kwargs)
    contexti = load_context(df, question, "messages", "individual").iloc[0]
    if contexti == "":
        return ""
    
    if make_explanation:
        return NAME2PROMPT_EXPL[question].format(convo=contexti, start_inst=START_INST, end_inst=END_INST)
    else:
        return NAME2PROMPT[question].format(convo=contexti, start_inst=START_INST, end_inst=END_INST)

def post_process_response(full_response, delimiter="\n\n", n=2):
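    """Split the raw generation into the short answer label and its explanation."""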
    parts = full_response.split(delimiter)[:n]
    response = extract_response(parts[0])
    logger.debug(f"Response extracted is {response}")
    # Keep the longer of the first two parts; the explanation usually lives there.
    if len(parts) > 1 and len(parts[1]) > len(parts[0]):
        full_response = parts[1]
    else:
        full_response = parts[0]
    # removeprefix (not lstrip) so only the leading answer string is removed,
    # not every leading character that happens to appear in it.
    explanation = full_response.removeprefix(response).lstrip(string.punctuation)
    explanation = explanation.strip()
    logger.debug(f"Explanation extracted is {explanation}")
    return response, explanation

def TA_predict_convo(memory, question, make_explanation=False, **kwargs):
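    """Score the conversation on `question` via the TA serving endpoint.

    Makes a first request with the short classification prompt (max_tokens=3).
    When `make_explanation` is True, a second request continues from the
    extracted answer to generate a free-text explanation (max_tokens=128).
    Returns a (full_convo, prompt, full_response) tuple; on a request error
    it redirects to the model loader page instead.
    """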
    full_convo = memory.load_memory_variables({})[memory.memory_key]
    PROMPT = get_context(memory, question, make_explanation=False, **kwargs)
    logger.debug(f"Raw TA prompt is {PROMPT}")
    if PROMPT == "":
        full_response = get_default(question, make_explanation)
        return full_convo, PROMPT, full_response
    
    body_request = {
        "prompt": PROMPT,
        "temperature": 0,
        "max_tokens": 3,
    }

    try:
        # Send request to Serving
        response = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
        if response.status_code == 200:
            response = response.json()
        else:
            raise Exception(f"Error in response: {response.json()}")
        full_response = response[0]['choices'][0]['text']
        if not make_explanation:
            return full_convo, PROMPT, full_response
        else:
            # Use a distinct name to avoid shadowing the module-level extract_response function.
            extracted_response, _ = post_process_response(full_response)
            PROMPT = get_context(memory, question, make_explanation=True, **kwargs)
            PROMPT = PROMPT + f" {extracted_response}"
            logger.debug(f"Raw TA prompt for Explanation is {PROMPT}")
            body_request["prompt"] = PROMPT
            body_request["max_tokens"] = 128
            response_expl = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
            if response_expl.status_code == 200:
                response_expl = response_expl.json()
            else:
                raise Exception(f"Error in response: {response_expl.json()}")
            full_response_expl = f"{extracted_response} {response_expl[0]['choices'][0]['text']}"
            return full_convo, PROMPT, full_response_expl
    except Exception as e:
        logger.debug(f"Error in response: {e}")
        st.switch_page("pages/model_loader.py")

def extract_response(x: str, default: str = TA_OPTIONS[0]) -> str:
    """Extract Response from generated answer
    Extract only search strings

    Args:
        x (str): prediction
        default (str, optional): default in case no response founds. Defaults to "N/A".

    Returns:
        str: _description_
    """

    try:
        return re.findall("|".join(TA_OPTIONS), x)[0]
    except Exception:
        return default
    
def ta_push_convo_comparison(ytrue, ypred):
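    """Persist a (human score, model score) pair for the current conversation to MongoDB."""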
    new_convo_scoring_comparison(
        client=st.session_state["db_client"],
        convo_id=st.session_state["convo_id"],
        context=st.session_state["context"] + "\nhelper:" + st.session_state["last_message"],
        ytrue=ytrue,
        ypred=ypred,
    )
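
# Usage sketch (illustrative only; "active_engagement" is a hypothetical
# question key and `memory` a populated LangChain conversation memory):
#
#   full_convo, prompt, answer = TA_predict_convo(
#       memory, "active_engagement", make_explanation=True
#   )
#   label, explanation = post_process_response(answer)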