from enum import Enum
from typing import Any, Dict

# from langchain_community.chat_models.huggingface import ChatHuggingFace
# from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core import pydantic_v1
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.utils import get_from_dict_or_env
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI


class LLMBackends(Enum):
    """Supported LLM inference backends."""

    VLLM = "VLLM"
    HFChat = "HFChat"
    Fireworks = "Fireworks"


class LazyChatHuggingFace(ChatHuggingFace):
    """ChatHuggingFace that defers model-id resolution and tokenizer loading.

    Bypasses ``ChatHuggingFace.__init__`` so nothing is fetched from the Hub
    unless the model id or tokenizer is actually missing.
    """

    def __init__(self, **kwargs: Any):
        # Skip ChatHuggingFace.__init__ and initialize the base chat model only.
        BaseChatModel.__init__(self, **kwargs)

        from transformers import AutoTokenizer

        if not self.model_id:
            self._resolve_model_id()

        self.tokenizer = (
            AutoTokenizer.from_pretrained(self.model_id)
            if self.tokenizer is None
            else self.tokenizer
        )


class LazyHuggingFaceEndpoint(HuggingFaceEndpoint):
    """HuggingFaceEndpoint that builds its inference clients without logging in.

    We use a lazy endpoint to avoid logging in with hf_token, which might in
    fact be an hf_oauth token that only permits inference, not logging in.
    """

    @pydantic_v1.root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Keep the parent class's pre-validation of extra model kwargs.
        return super().build_extra(values)

    @pydantic_v1.root_validator()
    def validate_environment(cls, values: Dict) -> Dict:  # noqa: UP006, N805
        """Check that huggingface_hub is installed and build the inference clients.

        Unlike the parent validator, this never logs in; the token is only
        passed through to the clients.
        """
        try:
            from huggingface_hub import AsyncInferenceClient, InferenceClient
        except ImportError:
            msg = (
                "Could not import huggingface_hub python package. "
                "Please install it with `pip install huggingface_hub`."
            )
            raise ImportError(msg)  # noqa: B904

        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HF_TOKEN"
        )

        values["client"] = InferenceClient(
            model=values["model"],
            timeout=values["timeout"],
            token=huggingfacehub_api_token,
            **values["server_kwargs"],
        )
        values["async_client"] = AsyncInferenceClient(
            model=values["model"],
            timeout=values["timeout"],
            token=huggingfacehub_api_token,
            **values["server_kwargs"],
        )

        return values


def get_chat_model_wrapper(
    model_id: str,
    inference_server_url: str,
    token: str,
    backend: str = "HFChat",
    **model_init_kwargs,
) -> BaseChatModel:
    """Build a chat model for ``model_id`` served by the given backend."""
    # Coerce the string to the enum; raises ValueError for unknown backends.
    backend = LLMBackends(backend)

    if backend == LLMBackends.HFChat:
        # llm = LazyHuggingFaceEndpoint(
        #     endpoint_url=inference_server_url,
        #     task="text-generation",
        #     huggingfacehub_api_token=token,
        #     **model_init_kwargs,
        # )
        llm = LazyHuggingFaceEndpoint(
            repo_id=model_id,
            task="text-generation",
            huggingfacehub_api_token=token,
            **model_init_kwargs,
        )

        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)

        chat_model = LazyChatHuggingFace(
            llm=llm, model_id=model_id, tokenizer=tokenizer
        )
    elif backend in [LLMBackends.VLLM, LLMBackends.Fireworks]:
        # Both vLLM and Fireworks expose OpenAI-compatible APIs.
        chat_model = ChatOpenAI(
            model=model_id,
            openai_api_base=inference_server_url,  # type: ignore
            openai_api_key=token,  # type: ignore
            **model_init_kwargs,
        )
    else:
        raise ValueError(f"Backend {backend} not supported")

    return chat_model
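

# Usage sketch: the model id, server URL, and token below are illustrative
# placeholders (assumptions, not values defined by this module); any
# OpenAI-compatible vLLM server would be wired up the same way.
if __name__ == "__main__":
    chat_model = get_chat_model_wrapper(
        model_id="mistralai/Mistral-7B-Instruct-v0.2",  # placeholder model id
        inference_server_url="http://localhost:8000/v1",  # placeholder vLLM URL
        token="dummy-key",  # placeholder API key
        backend="VLLM",
    )
    # Requires a server actually listening at inference_server_url.
    print(chat_model.invoke("Hello!").content)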