"""Add support for llama-cpp-python models to LiteLLM.""" import asyncio import logging import warnings from collections.abc import AsyncIterator, Callable, Iterator from functools import cache from typing import Any, ClassVar, cast import httpx import litellm from litellm import ( # type: ignore[attr-defined] CustomLLM, GenericStreamingChunk, ModelResponse, convert_to_model_response_object, ) from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from llama_cpp import ( # type: ignore[attr-defined] ChatCompletionRequestMessage, CreateChatCompletionResponse, CreateChatCompletionStreamResponse, Llama, LlamaRAMCache, ) # Reduce the logging level for LiteLLM and flashrank. logging.getLogger("litellm").setLevel(logging.WARNING) logging.getLogger("flashrank").setLevel(logging.WARNING) class LlamaCppPythonLLM(CustomLLM): """A llama-cpp-python provider for LiteLLM. This provider enables using llama-cpp-python models with LiteLLM. The LiteLLM model specification is "llama-cpp-python//@", where n_ctx is an optional parameter that specifies the context size of the model. If n_ctx is not provided or if it's set to 0, the model's default context size is used. Example usage: ```python from litellm import completion response = completion( model="llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@4092", messages=[{"role": "user", "content": "Hello world!"}], # stream=True ) ``` """ # Create a lock to prevent concurrent access to llama-cpp-python models. streaming_lock: ClassVar[asyncio.Lock] = asyncio.Lock() # The set of supported OpenAI parameters is the intersection of [1] and [2]. Not included: # max_completion_tokens, stream_options, n, user, logprobs, top_logprobs, extra_headers. # [1] https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion # [2] https://docs.litellm.ai/docs/completion/input supported_openai_params: ClassVar[list[str]] = [ "functions", # Deprecated "function_call", # Deprecated "tools", "tool_choice", "temperature", "top_p", "top_k", "min_p", "typical_p", "stop", "seed", "response_format", "max_tokens", "presence_penalty", "frequency_penalty", "repeat_penalty", "tfs_z", "mirostat_mode", "mirostat_tau", "mirostat_eta", "logit_bias", ] @staticmethod @cache def llm(model: str, **kwargs: Any) -> Llama: # Drop the llama-cpp-python prefix from the model. repo_id_filename = model.replace("llama-cpp-python/", "") # Convert the LiteLLM model string to repo_id, filename, and n_ctx. repo_id, filename = repo_id_filename.rsplit("/", maxsplit=1) n_ctx = 0 if len(filename_n_ctx := filename.rsplit("@", maxsplit=1)) == 2: # noqa: PLR2004 filename, n_ctx_str = filename_n_ctx n_ctx = int(n_ctx_str) # Load the LLM. with warnings.catch_warnings(): # Filter huggingface_hub warning about HF_TOKEN. warnings.filterwarnings("ignore", category=UserWarning) llm = Llama.from_pretrained( repo_id=repo_id, filename=filename, n_ctx=n_ctx, n_gpu_layers=-1, verbose=False, **kwargs, ) # Enable caching. llm.set_cache(LlamaRAMCache()) # Register the model info with LiteLLM. 
        litellm.register_model(  # type: ignore[attr-defined]
            {
                model: {
                    "max_tokens": llm.n_ctx(),
                    "max_input_tokens": llm.n_ctx(),
                    "max_output_tokens": None,
                    "input_cost_per_token": 0.0,
                    "output_cost_per_token": 0.0,
                    "output_vector_size": llm.n_embd() if kwargs.get("embedding") else None,
                    "litellm_provider": "llama-cpp-python",
                    "mode": "embedding" if kwargs.get("embedding") else "completion",
                    "supported_openai_params": LlamaCppPythonLLM.supported_openai_params,
                    "supports_function_calling": True,
                    "supports_parallel_function_calling": True,
                    "supports_vision": False,
                }
            }
        )
        return llm

    def completion(  # noqa: PLR0913
        self,
        model: str,
        messages: list[ChatCompletionRequestMessage],
        api_base: str,
        custom_prompt_dict: dict[str, Any],
        model_response: ModelResponse,
        print_verbose: Callable,  # type: ignore[type-arg]
        encoding: str,
        api_key: str,
        logging_obj: Any,
        optional_params: dict[str, Any],
        acompletion: Callable | None = None,  # type: ignore[type-arg]
        litellm_params: dict[str, Any] | None = None,
        logger_fn: Callable | None = None,  # type: ignore[type-arg]
        headers: dict[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,
        client: HTTPHandler | None = None,
    ) -> ModelResponse:
        llm = self.llm(model)
        llama_cpp_python_params = {
            k: v for k, v in optional_params.items() if k in self.supported_openai_params
        }
        response = cast(
            CreateChatCompletionResponse,
            llm.create_chat_completion(messages=messages, **llama_cpp_python_params),
        )
        litellm_model_response: ModelResponse = convert_to_model_response_object(
            response_object=response,
            model_response_object=model_response,
            response_type="completion",
            stream=False,
        )
        return litellm_model_response

    def streaming(  # noqa: PLR0913
        self,
        model: str,
        messages: list[ChatCompletionRequestMessage],
        api_base: str,
        custom_prompt_dict: dict[str, Any],
        model_response: ModelResponse,
        print_verbose: Callable,  # type: ignore[type-arg]
        encoding: str,
        api_key: str,
        logging_obj: Any,
        optional_params: dict[str, Any],
        acompletion: Callable | None = None,  # type: ignore[type-arg]
        litellm_params: dict[str, Any] | None = None,
        logger_fn: Callable | None = None,  # type: ignore[type-arg]
        headers: dict[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,
        client: HTTPHandler | None = None,
    ) -> Iterator[GenericStreamingChunk]:
        llm = self.llm(model)
        llama_cpp_python_params = {
            k: v for k, v in optional_params.items() if k in self.supported_openai_params
        }
        stream = cast(
            Iterator[CreateChatCompletionStreamResponse],
            llm.create_chat_completion(messages=messages, **llama_cpp_python_params, stream=True),
        )
        for chunk in stream:
            choices = chunk.get("choices", [])
            for choice in choices:
                text = choice.get("delta", {}).get("content", None)
                finish_reason = choice.get("finish_reason")
                litellm_generic_streaming_chunk = GenericStreamingChunk(
                    text=text,  # type: ignore[typeddict-item]
                    is_finished=bool(finish_reason),
                    finish_reason=finish_reason,  # type: ignore[typeddict-item]
                    usage=None,
                    index=choice.get("index"),  # type: ignore[typeddict-item]
                    provider_specific_fields={
                        "id": chunk.get("id"),
                        "model": chunk.get("model"),
                        "created": chunk.get("created"),
                        "object": chunk.get("object"),
                    },
                )
                yield litellm_generic_streaming_chunk

    async def astreaming(  # type: ignore[misc,override]  # noqa: PLR0913
        self,
        model: str,
        messages: list[ChatCompletionRequestMessage],
        api_base: str,
        custom_prompt_dict: dict[str, Any],
        model_response: ModelResponse,
        print_verbose: Callable,  # type: ignore[type-arg]
        encoding: str,
        api_key: str,
        logging_obj: Any,
        optional_params: dict[str, Any],
        acompletion: Callable | None = None,  # type: ignore[type-arg]
        litellm_params: dict[str, Any] | None = None,
        logger_fn: Callable | None = None,  # type: ignore[type-arg]
        headers: dict[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,  # noqa: ASYNC109
        client: AsyncHTTPHandler | None = None,
    ) -> AsyncIterator[GenericStreamingChunk]:
        # Start a synchronous stream.
        stream = self.streaming(
            model,
            messages,
            api_base,
            custom_prompt_dict,
            model_response,
            print_verbose,
            encoding,
            api_key,
            logging_obj,
            optional_params,
            acompletion,
            litellm_params,
            logger_fn,
            headers,
            timeout,
        )
        await asyncio.sleep(0)  # Yield control to the event loop after initialising the context.
        # Wrap the synchronous stream in an asynchronous stream.
        async with LlamaCppPythonLLM.streaming_lock:
            for litellm_generic_streaming_chunk in stream:
                yield litellm_generic_streaming_chunk
                await asyncio.sleep(0)  # Yield control to the event loop after each token.


# Register the LlamaCppPythonLLM provider.
if not any(provider["provider"] == "llama-cpp-python" for provider in litellm.custom_provider_map):
    litellm.custom_provider_map.append(
        {"provider": "llama-cpp-python", "custom_handler": LlamaCppPythonLLM()}
    )
litellm.suppress_debug_info = True
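
# Usage sketch (illustrative only, assuming this module has been imported so that the provider
# registration above has run): asynchronous streaming goes through `litellm.acompletion` with
# `stream=True`, which LiteLLM routes to `LlamaCppPythonLLM.astreaming`. The model string is the
# same one used in the class docstring; any other GGUF repo/filename would work as well.
#
#     import asyncio
#     from litellm import acompletion
#
#     async def main() -> None:
#         response = await acompletion(
#             model="llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@4092",
#             messages=[{"role": "user", "content": "Hello world!"}],
#             stream=True,
#         )
#         async for chunk in response:
#             print(chunk.choices[0].delta.content or "", end="", flush=True)
#
#     asyncio.run(main())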