File size: 2,664 Bytes
f3e0ba5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import streamlit as st
from streamlit.logger import get_logger
import requests
import os
from .config import model_name_or_path, BP_THRESHOLD
from transformers import AutoTokenizer
from utils.mongo_utils import new_bp_comparison
from app_config import ENDPOINT_NAMES

# Module logger obtained through Streamlit's logging helper.
logger = get_logger(__name__)

# Tokenizer for the bad-practice classifier. truncation_side="left" makes any
# truncation drop tokens from the start of a sequence rather than the end, so the
# most recent part of the conversation context is the part that is kept.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side="left")
# Serving URL for the "BadPractices" endpoint. DATABRICKS_URL is used as a format
# template with an "{endpoint_name}" placeholder. os.environ[...] raises KeyError at
# import time if the variable is not set.
BP_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["BadPractices"]['name'])
# Request headers for the serving endpoint: bearer-token auth (DATABRICKS_TOKEN must
# be set at import time) and a JSON body.
HEADERS = {
    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
    "Content-Type": "application/json",
}

def bp_predict_message(context, input, timeout=60):
    """Classify a message for bad practices using the Databricks serving endpoint.

    The context and the message are tokenized as a sentence pair. Only the
    context is truncated (``truncation="only_first"``), and because the module
    tokenizer uses ``truncation_side="left"``, its oldest tokens are dropped
    first. The first and last ids (the pair's outer special tokens) are
    stripped and the rest is decoded back to text, which is what gets sent to
    the endpoint.

    Args:
        context: Conversation history preceding the message.
        input: The message to classify. (Shadows the builtin; the name is kept
            so keyword callers keep working.)
        timeout: Seconds to wait for the serving endpoint before the request
            fails. Without it a stalled endpoint would block the page forever.

    Returns:
        A list with one dict per entry of the first prediction returned by the
        endpoint. Each dict is copied key-for-key, except ``"score"``, which is
        replaced by the boolean ``score > BP_THRESHOLD``.

        On any failure (network error, timeout, non-200 status, unexpected
        response shape), the error is logged and ``st.switch_page`` sends the
        user to ``pages/model_loader.py``. If that call returns, this function
        returns None.
    """
    input_ids = tokenizer(
        context,
        input,
        truncation="only_first",
        # Leave room for special tokens; presumably the endpoint re-tokenizes
        # the decoded text and adds its own. TODO confirm against the served model.
        max_length=tokenizer.model_max_length - 2,
    )["input_ids"]
    body_request = {
        # [1:-1] drops the leading and trailing special tokens before decoding.
        "inputs": [tokenizer.decode(input_ids[1:-1])],
        "params": {"top_k": None},
    }

    try:
        logger.debug(f"raw BP body is {body_request}")
        response = requests.post(
            url=BP_URL, headers=HEADERS, json=body_request, timeout=timeout
        )
        if response.status_code != 200:
            # Use the raw text: error bodies (e.g. gateway pages) are not always JSON.
            raise Exception(f"Error in response: {response.text}")
        predictions = response.json()["predictions"][0]
        logger.debug(f"Raw BP prediction is {predictions}")
        return [
            {k: v > BP_THRESHOLD if k == "score" else v for k, v in label_pred.items()}
            for label_pred in predictions.values()
        ]
    except Exception as e:
        # UI boundary: record the failure visibly, then route to the model loader page.
        logger.exception(f"Error in response: {e}")
        st.switch_page("pages/model_loader.py")

def bp_push2db(manual_confirmation=None):
    """Persist one bad-practice comparison: manual label versus model prediction.

    Args:
        manual_confirmation: Ground-truth dict with boolean keys ``is_advice``
            and ``is_personal_info``. When None, it is derived from
            ``st.session_state.sel_bp``:
            - "Advice" sets ``is_advice``.
            - "Personal Info" sets ``is_personal_info``.
            - "Advice & Personal Info" sets both.
            - Any other value leaves both False.

    The record is written via ``new_bp_comparison`` with these values from
    ``st.session_state``:
    - the DB client, conversation id, model source, context and last message;
    - ``bp_prediction``, reduced to a ``{label: score}`` mapping.
    """
    if manual_confirmation is None:
        # (selection label, (is_advice, is_personal_info)), checked in order.
        selection_flags = (
            ("Advice", (True, False)),
            ("Personal Info", (False, True)),
            ("Advice & Personal Info", (True, True)),
        )
        is_advice, is_personal_info = next(
            (
                flags
                for label, flags in selection_flags
                if st.session_state.sel_bp == label
            ),
            (False, False),
        )
        manual_confirmation = {
            "is_advice": is_advice,
            "is_personal_info": is_personal_info,
        }

    new_bp_comparison(
        client=st.session_state['db_client'],
        convo_id=st.session_state['convo_id'],
        model=st.session_state['source'],
        context=st.session_state["context"],
        last_message=st.session_state["last_message"],
        ytrue=manual_confirmation,
        ypred={pred['label']: pred['score'] for pred in st.session_state['bp_prediction']},
    )