import os
import re
import logging
from models.custom_parsers import CustomStringOutputParser
from utils.app_utils import get_random_name
from app_config import ENDPOINT_NAMES, SOURCES
from models.databricks.custom_databricks_llm import CustomDatabricksLLM
from langchain.chains import ConversationChain
from langchain.prompts import PromptTemplate

from typing import Any, List, Mapping, Optional, Dict

# Prompt skeleton for the texter simulator: LangChain substitutes {history}
# with the running conversation transcript and {input} with the helper's
# latest message; the model then completes the "texter:" turn.
_DATABRICKS_TEMPLATE_ = """{history}
helper: {input}
texter:"""

def get_databricks_chain(source, issue, language, memory, temperature=0.8, texter_name="Kit"):
    """Build a ConversationChain backed by a Databricks model-serving endpoint.

    Args:
        source: Key into ENDPOINT_NAMES selecting which serving endpoint to use.
        issue: Scenario/issue identifier forwarded to the texter LLM.
        language: Conversation language forwarded to the texter LLM.
        memory: LangChain memory object holding the conversation history.
        temperature: Sampling temperature for the endpoint (default 0.8).
        texter_name: Display name for the simulated texter (default "Kit").

    Returns:
        A ``(llm_chain, None)`` tuple; the second element is a placeholder so
        the return shape matches other chain factories in this project.

    Raises:
        KeyError: If DATABRICKS_URL or DATABRICKS_TOKEN is missing from the
            environment.
    """
    # BUG FIX: the previous default, ENDPOINT_NAMES.get(source, "texter_simulator"),
    # returned a plain string for unknown sources, so the ['name'] lookup raised
    # TypeError. The fallback must be a dict with the same shape as the
    # configured entries.
    endpoint_name = ENDPOINT_NAMES.get(source, {"name": "texter_simulator"})["name"]

    PROMPT = PromptTemplate(
        input_variables=["history", "input"],
        template=_DATABRICKS_TEMPLATE_,
    )

    # DATABRICKS_URL is a format string with an {endpoint_name} placeholder.
    llm = CustomDatabricksLLM(
        endpoint_url=os.environ["DATABRICKS_URL"].format(endpoint_name=endpoint_name),
        bearer_token=os.environ["DATABRICKS_TOKEN"],
        texter_name=texter_name,
        issue=issue,
        language=language,
        temperature=temperature,
        max_tokens=256,
    )

    llm_chain = ConversationChain(
        llm=llm,
        prompt=PROMPT,
        memory=memory,
        output_parser=CustomStringOutputParser(),
        verbose=True,
    )

    logging.debug("loaded Databricks model")
    return llm_chain, None

def cpc_is_alive():
    """Health-check the CPC model-serving endpoint.

    Sends a minimal scoring request and reports reachability.

    Returns:
        True if the endpoint answers HTTP 200 within 2 seconds, else False.
    """
    # BUG FIX: `requests` was used here but never imported anywhere in the
    # file; the bare `except:` silently masked the resulting NameError, so
    # this function always returned False. Import locally so the module-level
    # import list is untouched.
    import requests

    # NOTE(review): CPC_URL and HEADERS are not defined in this file chunk —
    # confirm they are imported/defined at module level.
    body_request = {"inputs": [""]}
    try:
        # Send request to Serving; short timeout since this is only a liveness probe.
        response = requests.post(url=CPC_URL, headers=HEADERS, json=body_request, timeout=2)
        return response.status_code == 200
    except requests.RequestException:
        # Endpoint unreachable, timed out, or connection refused -> treat as down.
        return False