import os
import re
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM
from langchain_core.pydantic_v1 import (
    BaseModel,
    Extra,
    Field,
    PrivateAttr,
    root_validator,
    validator,
)

__all__ = ["Databricks"]

class _DatabricksClientBase(BaseModel, ABC):
    """A base JSON API client that talks to Databricks."""

    api_url: str
    api_token: str

    def request(self, method: str, url: str, request: Any) -> Any:
        headers = {"Authorization": f"Bearer {self.api_token}"}
        response = requests.request(
            method=method, url=url, headers=headers, json=request
        )
        # TODO: error handling and automatic retries
        if not response.ok:
            raise ValueError(f"HTTP {response.status_code} error: {response.text}")
        return response.json()

    def _get(self, url: str) -> Any:
        return self.request("GET", url, None)

    def _post(self, url: str, request: Any) -> Any:
        return self.request("POST", url, request)

    @abstractmethod
    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        ...

    @property
    def llm(self) -> bool:
        return False

def _transform_completions(response: Dict[str, Any]) -> str:
    return response["choices"][0]["text"]


def _transform_llama2_chat(response: Dict[str, Any]) -> str:
    return response["candidates"][0]["text"]


def _transform_chat(response: Dict[str, Any]) -> str:
    return response["choices"][0]["message"]["content"]
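
# The transforms above assume the following response shapes (illustrative
# only; just the accessed fields are shown):
#   _transform_completions: {"choices": [{"text": "..."}]}
#   _transform_llama2_chat: {"candidates": [{"text": "..."}]}
#   _transform_chat:        {"choices": [{"message": {"content": "..."}}]}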

class _DatabricksServingEndpointClient(_DatabricksClientBase):
    """An API client that talks to a Databricks serving endpoint."""

    host: str
    endpoint_name: str
    databricks_uri: str
    client: Any = None
    external_or_foundation: bool = False
    task: Optional[str] = None

    def __init__(self, **data: Any):
        super().__init__(**data)

        try:
            from mlflow.deployments import get_deploy_client

            self.client = get_deploy_client(self.databricks_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                "Please install mlflow with `pip install mlflow`."
            ) from e

        endpoint = self.client.get_endpoint(self.endpoint_name)
        self.external_or_foundation = endpoint.get("endpoint_type", "").lower() in (
            "external_model",
            "foundation_model_api",
        )
        if self.task is None:
            self.task = endpoint.get("task")

    @property
    def llm(self) -> bool:
        return self.task in ("llm/v1/chat", "llm/v1/completions", "llama2/chat")

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "api_url" not in values:
            host = values["host"]
            endpoint_name = values["endpoint_name"]
            api_url = f"https://{host}/serving-endpoints/{endpoint_name}/invocations"
            values["api_url"] = api_url
        return values

    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        if self.external_or_foundation:
            resp = self.client.predict(endpoint=self.endpoint_name, inputs=request)
            if transform_output_fn:
                return transform_output_fn(resp)

            if self.task == "llm/v1/chat":
                return _transform_chat(resp)
            elif self.task == "llm/v1/completions":
                return _transform_completions(resp)

            return resp
        else:
            # See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
            wrapped_request = {"dataframe_records": [request]}
            response = self.client.predict(
                endpoint=self.endpoint_name, inputs=wrapped_request
            )
            preds = response["predictions"]
            # For a single-record query, the result is not a list.
            pred = preds[0] if isinstance(preds, list) else preds
            if self.task == "llama2/chat":
                return _transform_llama2_chat(pred)
            return transform_output_fn(pred) if transform_output_fn else pred

class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
    """An API client that talks to a Databricks cluster driver proxy app."""

    host: str
    cluster_id: str
    cluster_driver_port: str

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "api_url" not in values:
            host = values["host"]
            cluster_id = values["cluster_id"]
            port = values["cluster_driver_port"]
            api_url = f"https://{host}/driver-proxy-api/o/0/{cluster_id}/{port}"
            values["api_url"] = api_url
        return values

    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        resp = self._post(self.api_url, request)
        return transform_output_fn(resp) if transform_output_fn else resp

def get_repl_context() -> Any:
    """Get the notebook REPL context if running inside a Databricks notebook.

    Raises an ImportError when not running inside a Databricks notebook.
    """
    try:
        from dbruntime.databricks_repl_context import get_context

        return get_context()
    except ImportError:
        raise ImportError(
            "Cannot access dbruntime, not running inside a Databricks notebook."
        )

def get_default_host() -> str:
    """Get the default Databricks workspace hostname.

    Raises an error if the hostname cannot be automatically determined.
    """
    host = os.getenv("DATABRICKS_HOST")
    if not host:
        try:
            host = get_repl_context().browserHostName
            if not host:
                raise ValueError("context doesn't contain browserHostName.")
        except Exception as e:
            raise ValueError(
                "host was not set and cannot be automatically inferred. Set "
                f"environment variable 'DATABRICKS_HOST'. Received error: {e}"
            )
    # TODO: support Databricks CLI profile
    # Note: str.lstrip strips a *set of characters*, not a prefix, so use a
    # regex to remove an optional URL scheme instead.
    host = re.sub(r"^https?://", "", host).rstrip("/")
    return host

def get_default_api_token() -> str:
    """Get the default Databricks personal access token.

    Raises an error if the token cannot be automatically determined.
    """
    if api_token := os.getenv("DATABRICKS_TOKEN"):
        return api_token
    try:
        api_token = get_repl_context().apiToken
        if not api_token:
            raise ValueError("context doesn't contain apiToken.")
    except Exception as e:
        raise ValueError(
            "api_token was not set and cannot be automatically inferred. Set "
            f"environment variable 'DATABRICKS_TOKEN'. Received error: {e}"
        )
    # TODO: support Databricks CLI profile
    return api_token

def _is_hex_string(data: str) -> bool:
    """Checks whether ``data`` is a valid hexadecimal string."""
    if not isinstance(data, str):
        return False
    pattern = r"^[0-9a-fA-F]+$"
    return bool(re.match(pattern, data))

def _load_pickled_fn_from_hex_string(
    data: str, allow_dangerous_deserialization: Optional[bool]
) -> Callable:
    """Loads a pickled function from a hexadecimal string."""
    if not allow_dangerous_deserialization:
        raise ValueError(
            "This code relies on the pickle module. "
            "You will need to set allow_dangerous_deserialization=True "
            "if you want to opt-in to allow deserialization of data using pickle. "
            "Data can be compromised by a malicious actor if "
            "not handled properly to include "
            "a malicious payload that when deserialized with "
            "pickle can execute arbitrary code on your machine."
        )

    try:
        import cloudpickle
    except Exception as e:
        raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")

    try:
        return cloudpickle.loads(bytes.fromhex(data))
    except Exception as e:
        raise ValueError(
            f"Failed to load the pickled function from a hexadecimal string. Error: {e}"
        )

def _pickle_fn_to_hex_string(fn: Callable) -> str:
    """Pickles a function and returns the hexadecimal string."""
    try:
        import cloudpickle
    except Exception as e:
        raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")

    try:
        return cloudpickle.dumps(fn).hex()
    except Exception as e:
        raise ValueError(f"Failed to pickle the function: {e}")

class Databricks(LLM):
    """Databricks serving endpoint or a cluster driver proxy app for LLM.

    It supports two endpoint types:

    * **Serving endpoint** (recommended for both production and development).
      We assume that an LLM was deployed to a serving endpoint.
      To wrap it as an LLM you must have "Can Query" permission to the endpoint.
      Set ``endpoint_name`` accordingly and do not set ``cluster_id`` and
      ``cluster_driver_port``.

      If the underlying model is a model registered by MLflow, the expected
      model signature is:

      * inputs::

          [{"name": "prompt", "type": "string"},
           {"name": "stop", "type": "list[string]"}]

      * outputs: ``[{"type": "string"}]``

      If the underlying model is an external or foundation model, the response
      from the endpoint is automatically transformed to the expected format
      unless ``transform_output_fn`` is provided.

    * **Cluster driver proxy app** (recommended for interactive development).
      One can load an LLM on a Databricks interactive cluster and start a local
      HTTP server on the driver node to serve the model at ``/`` using the HTTP
      POST method with JSON input/output.
      Please use a port number between ``[3000, 8000]`` and let the server
      listen to the driver IP address or simply ``0.0.0.0`` instead of
      localhost only.
      To wrap it as an LLM you must have "Can Attach To" permission to the
      cluster.
      Set ``cluster_id`` and ``cluster_driver_port`` and do not set
      ``endpoint_name``.

      The expected server schema (using JSON schema) is:

      * inputs::

          {"type": "object",
           "properties": {
              "prompt": {"type": "string"},
              "stop": {"type": "array", "items": {"type": "string"}}},
           "required": ["prompt"]}

      * outputs: ``{"type": "string"}``

    If the endpoint model signature is different or you want to set extra
    params, you can use ``transform_input_fn`` and ``transform_output_fn`` to
    apply necessary transformations before and after the query.
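
    A minimal usage sketch (the endpoint name and the transform functions
    below are illustrative, assuming this module is importable as
    ``langchain_community.llms``):

    .. code-block:: python

        from langchain_community.llms import Databricks

        # Wrap a model serving endpoint; host and token come from
        # DATABRICKS_HOST / DATABRICKS_TOKEN or the notebook context.
        llm = Databricks(endpoint_name="my-llm-endpoint", temperature=0.1)
        llm.invoke("What is MLflow?")

        # If the endpoint expects a different schema, adapt the request and
        # response with transform functions.
        def transform_input(**request):
            request["messages"] = [
                {"role": "user", "content": request.pop("prompt")}
            ]
            return request

        llm = Databricks(
            endpoint_name="my-llm-endpoint",
            transform_input_fn=transform_input,
            transform_output_fn=lambda resp: resp["choices"][0]["message"]["content"],
        )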
""" | |
    host: str = Field(default_factory=get_default_host)
    """Databricks workspace hostname.
    If not provided, the default value is determined by

    * the ``DATABRICKS_HOST`` environment variable if present, or
    * the hostname of the current Databricks workspace if running inside
      a Databricks notebook attached to an interactive cluster in "single user"
      or "no isolation shared" mode.
    """

    api_token: str = Field(default_factory=get_default_api_token)
    """Databricks personal access token.
    If not provided, the default value is determined by

    * the ``DATABRICKS_TOKEN`` environment variable if present, or
    * an automatically generated temporary token if running inside a Databricks
      notebook attached to an interactive cluster in "single user" or
      "no isolation shared" mode.
    """

    endpoint_name: Optional[str] = None
    """Name of the model serving endpoint.
    You must specify the endpoint name to connect to a model serving endpoint.
    You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_id: Optional[str] = None
    """ID of the cluster if connecting to a cluster driver proxy app.
    If neither ``endpoint_name`` nor ``cluster_id`` is provided and the code
    runs inside a Databricks notebook attached to an interactive cluster in
    "single user" or "no isolation shared" mode, the current cluster ID is
    used as default.
    You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_driver_port: Optional[str] = None
    """The port number used by the HTTP server running on the cluster driver
    node. The server should listen on the driver IP address or simply
    ``0.0.0.0`` to accept connections.
    We recommend using a port number between ``[3000, 8000]``.
    """

    model_kwargs: Optional[Dict[str, Any]] = None
    """
    Deprecated. Please use ``extra_params`` instead. Extra parameters to pass
    to the endpoint.
    """

    transform_input_fn: Optional[Callable] = None
    """A function that transforms ``{prompt, stop, **kwargs}`` into a
    JSON-compatible request object that the endpoint accepts.
    For example, you can apply a prompt template to the input prompt.
    """

    transform_output_fn: Optional[Callable[..., str]] = None
    """A function that transforms the output from the endpoint to the
    generated text.
    """

    databricks_uri: str = "databricks"
    """The Databricks URI. Only used when using a serving endpoint."""

    temperature: float = 0.0
    """The sampling temperature."""

    n: int = 1
    """The number of completion choices to generate."""

    stop: Optional[List[str]] = None
    """The stop sequence."""

    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""

    extra_params: Dict[str, Any] = Field(default_factory=dict)
    """Any extra parameters to pass to the endpoint."""

    task: Optional[str] = None
    """The task of the endpoint. Only used when using a serving endpoint.
    If not provided, the task is automatically inferred from the endpoint.
    """

    allow_dangerous_deserialization: bool = False
    """Whether to allow dangerous deserialization of the data which
    involves loading data using pickle.

    If the data has been modified by a malicious actor, it can deliver a
    malicious payload that results in execution of arbitrary code on the
    target machine.
    """

    _client: _DatabricksClientBase = PrivateAttr()

    class Config:
        extra = Extra.forbid
        underscore_attrs_are_private = True

    @property
    def _llm_params(self) -> Dict[str, Any]:
        params: Dict[str, Any] = {
            "temperature": self.temperature,
            "n": self.n,
        }
        if self.stop:
            params["stop"] = self.stop
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        return params

    @validator("cluster_id", always=True)
    def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
        if v and values["endpoint_name"]:
            raise ValueError("Cannot set both endpoint_name and cluster_id.")
        elif values["endpoint_name"]:
            return None
        elif v:
            return v
        else:
            try:
                if v := get_repl_context().clusterId:
                    return v
                raise ValueError("Context doesn't contain clusterId.")
            except Exception as e:
                raise ValueError(
                    "Neither endpoint_name nor cluster_id was set. "
                    "And the cluster_id cannot be automatically determined. "
                    f"Received error: {e}"
                )

    @validator("cluster_driver_port", always=True)
    def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
        if v and values["endpoint_name"]:
            raise ValueError("Cannot set both endpoint_name and cluster_driver_port.")
        elif values["endpoint_name"]:
            return None
        elif v is None:
            raise ValueError(
                "Must set cluster_driver_port to connect to a cluster driver."
            )
        elif int(v) <= 0:
            raise ValueError(f"Invalid cluster_driver_port: {v}")
        else:
            return v

    @validator("model_kwargs", always=True)
    def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        if v:
            assert "prompt" not in v, "model_kwargs must not contain key 'prompt'"
            assert "stop" not in v, "model_kwargs must not contain key 'stop'"
        return v

    def __init__(self, **data: Any):
        if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
            data["transform_input_fn"] = _load_pickled_fn_from_hex_string(
                data=data["transform_input_fn"],
                allow_dangerous_deserialization=data.get(
                    "allow_dangerous_deserialization"
                ),
            )
        if "transform_output_fn" in data and _is_hex_string(
            data["transform_output_fn"]
        ):
            data["transform_output_fn"] = _load_pickled_fn_from_hex_string(
                data=data["transform_output_fn"],
                allow_dangerous_deserialization=data.get(
                    "allow_dangerous_deserialization"
                ),
            )

        super().__init__(**data)
        if self.model_kwargs is not None and self.extra_params is not None:
            raise ValueError("Cannot set both model_kwargs and extra_params.")
        elif self.model_kwargs is not None:
            warnings.warn(
                "model_kwargs is deprecated. Please use extra_params instead.",
                DeprecationWarning,
            )
        if self.endpoint_name:
            self._client = _DatabricksServingEndpointClient(
                host=self.host,
                api_token=self.api_token,
                endpoint_name=self.endpoint_name,
                databricks_uri=self.databricks_uri,
                task=self.task,
            )
        elif self.cluster_id and self.cluster_driver_port:
            self._client = _DatabricksClusterDriverProxyClient(  # type: ignore[call-arg]
                host=self.host,
                api_token=self.api_token,
                cluster_id=self.cluster_id,
                cluster_driver_port=self.cluster_driver_port,
            )
        else:
            raise ValueError(
                "Must specify either endpoint_name or cluster_id/cluster_driver_port."
            )

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Return default params."""
        return {
            "host": self.host,
            # "api_token": self.api_token,  # Never save the token
            "endpoint_name": self.endpoint_name,
            "cluster_id": self.cluster_id,
            "cluster_driver_port": self.cluster_driver_port,
            "databricks_uri": self.databricks_uri,
            "model_kwargs": self.model_kwargs,
            "temperature": self.temperature,
            "n": self.n,
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "extra_params": self.extra_params,
            "task": self.task,
            "transform_input_fn": None
            if self.transform_input_fn is None
            else _pickle_fn_to_hex_string(self.transform_input_fn),
            "transform_output_fn": None
            if self.transform_output_fn is None
            else _pickle_fn_to_hex_string(self.transform_output_fn),
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "databricks"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Queries the LLM endpoint with the given prompt and stop sequence."""
        # TODO: support callbacks

        request: Dict[str, Any] = {"prompt": prompt}
        if self._client.llm:
            request.update(self._llm_params)
        request.update(self.model_kwargs or self.extra_params)
        request.update(kwargs)
        if stop:
            request["stop"] = stop

        if self.transform_input_fn:
            request = self.transform_input_fn(**request)

        return self._client.post(request, transform_output_fn=self.transform_output_fn)
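
# Usage sketch for the cluster driver proxy path (the cluster ID and port
# below are placeholders, assuming a driver proxy app as described in the
# class docstring is already running):
#
#   llm = Databricks(cluster_id="0000-000000-xxxxxxxx", cluster_driver_port="7777")
#   llm.invoke("How are you?")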