Spaces:
Sleeping
Sleeping
File size: 2,421 Bytes
9ff00d4 bdca921 9ff00d4 20b3b4a 9ff00d4 bdca921 20b3b4a 9ff00d4 20b3b4a 9ff00d4 20b3b4a 9ff00d4 20b3b4a 9ff00d4 20b3b4a 9ff00d4 20b3b4a 9ff00d4 20b3b4a 9ff00d4 20b3b4a 9ff00d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import os
import json
import requests
import logging
from streamlit.logger import get_logger
from models.custom_parsers import CustomStringOutputParser
from app_config import ENDPOINT_NAMES
from langchain.chains import ConversationChain
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain.prompts import PromptTemplate
from typing import Any, List, Mapping, Optional, Dict
# Module-level logger, created through Streamlit's logger factory
# (``streamlit.logger.get_logger``) and named after this module.
logger = get_logger(__name__)
class DatabricksCustomBizLLM(LLM):
    """LangChain ``LLM`` that forwards prompts to a Databricks HTTP endpoint.

    Every call POSTs a JSON document to ``db_url`` containing the prompt and
    the configured ``issue``, ``language``, ``temperature`` and ``max_tokens``
    (each wrapped in a one-element list), then returns the ``generated_text``
    of the first entry in the response's ``predictions`` list.
    """

    # Sent to the endpoint under the "issue" key of the payload.
    issue: str
    # Sent to the endpoint under the "language" key of the payload.
    language: str
    # Sent to the endpoint under the "temperature" key of the payload.
    temperature: float = 0.8
    # Sent to the endpoint under the "max_tokens" key of the payload.
    max_tokens: int = 128
    # Full URL that the POST request is sent to.
    db_url: str
    # NOTE: this default is built once, when the class body executes, so
    # DATABRICKS_TOKEN has to be present in the environment before this module
    # is imported; otherwise the header value becomes "Bearer None".
    headers: Mapping[str, str] = {
        'Authorization': f'Bearer {os.environ.get("DATABRICKS_TOKEN")}',
        'Content-Type': 'application/json',
    }
    # Seconds to wait on the endpoint before ``requests`` gives up. Without a
    # timeout a stalled endpoint would block the caller forever. Set to None
    # to restore the previous wait-indefinitely behaviour.
    request_timeout: Optional[float] = 300.0

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this LLM implementation."""
        return "custom_databricks_biz"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` to the endpoint and return the generated text.

        Args:
            prompt: Fully rendered prompt text.
            stop: Accepted for interface compatibility; not used here.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ``predictions[0]["generated_text"]`` from the JSON response.

        Raises:
            Exception: If the endpoint answers with a status other than 200.
            requests.exceptions.Timeout: If no answer arrives within
                ``request_timeout`` seconds.
        """
        payload = {'inputs': {
            'prompt': [prompt],
            'issue': [self.issue],
            'language': [self.language],
            'temperature': [self.temperature],
            'max_tokens': [self.max_tokens],
        }}
        response = requests.request(
            method='POST',
            headers=self.headers,
            url=self.db_url,
            data=json.dumps(payload, allow_nan=True),
            timeout=self.request_timeout,
        )
        if response.status_code != 200:
            raise Exception(f'Request failed with status {response.status_code}, {response.text}')
        return response.json()["predictions"][0]["generated_text"]
# Prompt layout with two placeholders: ``{history}`` on the first line and
# ``{input}`` after the "helper: " prefix; the text ends on an open "texter:"
# line (presumably the cue the model completes -- confirm against the endpoint).
_DATABRICKS_TEMPLATE_ = """{history}
helper: {input}
texter:"""
def get_databricks_biz_chain(source, issue, language, memory, temperature=0.8, max_tokens=256):
    """Build a ``ConversationChain`` backed by ``DatabricksCustomBizLLM``.

    Args:
        source: Key looked up in ``ENDPOINT_NAMES`` to pick the endpoint name;
            unknown keys fall back to ``"conversation_simulator"``.
        issue: Passed through to the LLM's ``issue`` field.
        language: Passed through to the LLM's ``language`` field.
        memory: Memory object handed to the ``ConversationChain``.
        temperature: Passed through to the LLM's ``temperature`` field.
        max_tokens: Passed through to the LLM's ``max_tokens`` field. The
            default of 256 matches the value that used to be hard-coded.

    Returns:
        A ``(chain, "helper:")`` tuple: the configured chain and the literal
        string ``"helper:"`` (presumably the speaker prefix the caller uses
        as a stop marker -- confirm against the caller).

    Raises:
        KeyError: If the ``DATABRICKS_URL`` environment variable is not set.
    """
    prompt = PromptTemplate(
        input_variables=['history', 'input'],
        template=_DATABRICKS_TEMPLATE_,
    )

    # DATABRICKS_URL is a format string with an ``{endpoint_name}`` slot.
    url_template = os.environ['DATABRICKS_URL']
    endpoint_name = ENDPOINT_NAMES.get(source, "conversation_simulator")

    llm = DatabricksCustomBizLLM(
        issue=issue,
        language=language,
        temperature=temperature,
        max_tokens=max_tokens,
        db_url=url_template.format(endpoint_name=endpoint_name),
    )

    llm_chain = ConversationChain(
        llm=llm,
        prompt=prompt,
        memory=memory,
        output_parser=CustomStringOutputParser(),
    )

    # Log through the module-level logger rather than the root logger.
    logger.debug("loaded Databricks Biz model")
    return llm_chain, "helper:"