Gregor Betz committed on
Commit
5c840c4
1 Parent(s): 58047c7

LazyHfEndpoint validator

Browse files
Files changed (2) hide show
  1. backend/models.py +14 -8
  2. requirements.txt +4 -3
backend/models.py CHANGED
@@ -46,6 +46,10 @@ class LazyHuggingFaceEndpoint(HuggingFaceEndpoint):
46
  # which might in fact be a hf_oauth token that does only permit inference,
47
  # not logging in.
48
 
 
 
 
 
49
  @pydantic_v1.root_validator()
50
  def validate_environment(cls, values: Dict) -> Dict: # noqa: UP006, N805
51
  """Validate that package is installed and that the API token is valid."""
@@ -83,7 +87,7 @@ def get_chat_model_wrapper(
83
  model_id: str,
84
  inference_server_url: str,
85
  token: str,
86
- backend: str = LLMBackends.HFChat,
87
  **model_init_kwargs
88
  ):
89
 
@@ -97,18 +101,17 @@ def get_chat_model_wrapper(
97
  # **model_init_kwargs,
98
  # )
99
 
100
- # from transformers import AutoTokenizer
101
-
102
- # tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
103
- # chat_model = LazyChatHuggingFace(llm=llm, model_id=model_id, tokenizer=tokenizer)
104
-
105
- llm = HuggingFaceEndpoint(
106
  repo_id=model_id,
107
  task="text-generation",
108
  huggingfacehub_api_token=token,
109
  **model_init_kwargs,
110
  )
111
- chat_model = ChatHuggingFace(llm=llm)
 
 
 
 
112
 
113
  elif backend in [LLMBackends.VLLM, LLMBackends.Fireworks]:
114
  chat_model = ChatOpenAI(
@@ -118,4 +121,7 @@ def get_chat_model_wrapper(
118
  **model_init_kwargs,
119
  )
120
 
 
 
 
121
  return chat_model
 
46
  # which might in fact be a hf_oauth token that does only permit inference,
47
  # not logging in.
48
 
49
+ @pydantic_v1.root_validator(pre=True)
50
+ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
51
+ return super().build_extra(values)
52
+
53
  @pydantic_v1.root_validator()
54
  def validate_environment(cls, values: Dict) -> Dict: # noqa: UP006, N805
55
  """Validate that package is installed and that the API token is valid."""
 
87
  model_id: str,
88
  inference_server_url: str,
89
  token: str,
90
+ backend: str = "HFChat",
91
  **model_init_kwargs
92
  ):
93
 
 
101
  # **model_init_kwargs,
102
  # )
103
 
104
+ llm = LazyHuggingFaceEndpoint(
 
 
 
 
 
105
  repo_id=model_id,
106
  task="text-generation",
107
  huggingfacehub_api_token=token,
108
  **model_init_kwargs,
109
  )
110
+
111
+ from transformers import AutoTokenizer
112
+
113
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
114
+ chat_model = LazyChatHuggingFace(llm=llm, model_id=model_id, tokenizer=tokenizer)
115
 
116
  elif backend in [LLMBackends.VLLM, LLMBackends.Fireworks]:
117
  chat_model = ChatOpenAI(
 
121
  **model_init_kwargs,
122
  )
123
 
124
+ else:
125
+ raise ValueError(f"Backend {backend} not supported")
126
+
127
  return chat_model
requirements.txt CHANGED
@@ -2,9 +2,10 @@ gradio==4.37.2
2
  aiohttp
3
  datasets
4
  huggingface_hub
5
- langchain
6
- langchain_huggingface
7
- langchain_openai
 
8
  sentencepiece
9
  transformers
10
  ujson
 
2
  aiohttp
3
  datasets
4
  huggingface_hub
5
+ langchain==0.2.7
6
+ langchain_core==0.2.16
7
+ langchain_huggingface==0.0.3
8
+ langchain-openai==0.1.15
9
  sentencepiece
10
  transformers
11
  ujson