import os
import re
import requests
import string
import streamlit as st
from streamlit.logger import get_logger
from app_config import ENDPOINT_NAMES
from models.ta_models.config import NAME2PROMPT, NAME2PROMPT_EXPL, START_INST, END_INST, QUESTIONDEFAULTS, TA_OPTIONS, NAME2QUESTION
import pandas as pd
from langchain_core.messages import AIMessage, HumanMessage
from models.ta_models.ta_prompt_utils import load_context
from utils.mongo_utils import new_convo_scoring_comparison

logger = get_logger(__name__)
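# Serving endpoint URL and auth headers, built from environment configuration.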
TA_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["training_adherence"]['name'])
HEADERS = {
    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
    "Content-Type": "application/json",
}

def memory2df(memory, conversation_id="convo1234"):
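    """Flatten the conversation memory into a DataFrame with one row per
    texter (AIMessage) or helper (HumanMessage) turn; any other message
    types are skipped."""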
    rows = []
    for i, msg in enumerate(memory.buffer_as_messages):
        # Texter turns are stored as AIMessage, helper turns as HumanMessage.
        actor_role = "texter" if isinstance(msg, AIMessage) else "helper" if isinstance(msg, HumanMessage) else None
        if actor_role:
            convo_part = msg.response_metadata.get("phase")
            rows.append({
                "conversation_id": conversation_id,
                "message_number": i + 1,
                "actor_role": actor_role,
                "message": msg.content,
                "convo_part": convo_part,
            })

    return pd.DataFrame(rows)

def get_default(question, make_explanation=False):
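    """Return the canned answer used when a question has no scorable context;
    `make_explanation` selects the variant with or without an explanation."""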
    return QUESTIONDEFAULTS[question][make_explanation]

def get_context(memory, question, make_explanation=False, **kwargs):
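    """Build the full TA prompt for `question` from the conversation memory,
    returning "" when the loaded context is empty."""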
    df = memory2df(memory, **kwargs)
    contexti = load_context(df, question, "messages", "individual").iloc[0]
    if contexti == "":
        return ""
    
    prompt_template = NAME2PROMPT_EXPL[question] if make_explanation else NAME2PROMPT[question]
    return prompt_template.format(convo=contexti, start_inst=START_INST, end_inst=END_INST)

def post_process_response(full_response, delimiter="\n\n", n=2):
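    """Split the raw model output into the extracted label (see
    extract_response) and the free-text explanation that follows it."""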
    parts = full_response.split(delimiter)[:n]
    response = extract_response(parts[0])
    logger.debug(f"Response extracted is {response}")
    if len(parts) > 1:
        # Keep the longer of the first two parts: the model sometimes emits a
        # short label followed by the full answer in the next paragraph.
        full_response = parts[1] if len(parts[0]) < len(parts[1]) else parts[0]
    else:
        full_response = parts[0]
    # str.lstrip would treat `response` as a character set, so strip it as a prefix.
    explanation = full_response.removeprefix(response).lstrip(string.punctuation).strip()
    logger.debug(f"Explanation extracted is {explanation}")
    return response, explanation

def TA_predict_convo(memory, question, make_explanation=False, **kwargs):
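    """Score one TA question over the conversation: build the prompt, call
    the serving endpoint, and return (full_convo, prompt, raw_model_text)."""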
    full_convo = memory.load_memory_variables({})[memory.memory_key]
    PROMPT = get_context(memory, question, make_explanation=make_explanation, **kwargs)
    logger.debug(f"Raw TA prompt is {PROMPT}")
    if PROMPT == "":
        full_response = get_default(question, make_explanation)
        return full_convo, PROMPT, full_response
    
    # The bare label needs only a few tokens; explanations need more room.
    max_tokens = 128 if make_explanation else 3
    body_request = {
        "prompt": PROMPT,
        "temperature": 0,  # deterministic scoring
        "max_tokens": max_tokens,
    }

    try:
        # Send request to the serving endpoint
        response = requests.post(url=TA_URL, headers=HEADERS, json=body_request)
        response.raise_for_status()
        full_response = response.json()[0]['choices'][0]['text']
        logger.debug(f"Raw TA response is {full_response}")
        return full_convo, PROMPT, full_response
    except Exception:
        # Don't return None on failure: fall back to the canned default so
        # callers can still unpack the (convo, prompt, response) triple.
        logger.exception("TA serving request failed; returning default response")
        return full_convo, PROMPT, get_default(question, make_explanation)

def extract_response(x: str, default: str = TA_OPTIONS[0]) -> str:
    """Extract the TA label from a generated answer.

    Args:
        x (str): Raw model prediction.
        default (str, optional): Returned when no label is found.
            Defaults to TA_OPTIONS[0].

    Returns:
        str: The first TA_OPTIONS label found in `x`, otherwise `default`.
    """

    try:
        # Escape the options in case any of them contain regex metacharacters.
        return re.findall("|".join(map(re.escape, TA_OPTIONS)), x)[0]
    except Exception:
        return default
    
def ta_push_convo_comparison(ytrue, ypred):
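    """Persist a human label (`ytrue`) vs. model prediction (`ypred`) pair
    for the current conversation so the two can be compared later."""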
    new_convo_scoring_comparison(
        client=st.session_state['db_client'],
        convo_id=st.session_state['convo_id'],
        context=st.session_state["context"] + "\nhelper:" + st.session_state["last_message"],
        ytrue=ytrue,
        ypred=ypred,
    )
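
# Hypothetical usage sketch (names like `memory` and the question
# "active_listening" are illustrative, not defined in this module):
#
#   full_convo, prompt, raw = TA_predict_convo(memory, "active_listening",
#                                              make_explanation=True)
#   label, explanation = post_process_response(raw)
#   ta_push_convo_comparison(ytrue=TA_OPTIONS[0], ypred=label)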