Gregor Betz commited on
Commit
379d37f
1 Parent(s): 574da11
Files changed (1) hide show
  1. app.py +6 -11
app.py CHANGED
@@ -26,20 +26,15 @@ CLIENT_MODEL_KWARGS = {
26
  }
27
 
28
  GUIDE_KWARGS = {
29
- "expert_model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
30
  # "meta-llama/Meta-Llama-3.1-70B-Instruct",
31
- # "accounts/fireworks/models/nous-hermes-2-mixtral-8x7b-dpo-fp8",
32
- # "accounts/fireworks/models/llama-v3-8b-instruct-hf",
33
- # "accounts/fireworks/models/nous-hermes-2-mixtral-8x7b-dpo-fp8",
34
- "inference_server_url": "https://api.fireworks.ai/inference/v1",
35
  # "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-70B-Instruct",
36
- # "https://api.fireworks.ai/inference/v1",
37
- "llm_backend": "Fireworks",
38
  "classifier_kwargs": {
39
  "model_id": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
40
- "inference_server_url": "https://sa710i91bnjvbhir.us-east-1.aws.endpoints.huggingface.cloud",
41
- # "inference_server_url": "https://api-inference.huggingface.co/models/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
42
- "batch_size": 128,
43
  },
44
  }
45
 
@@ -190,7 +185,7 @@ async def bot(
190
  if len(history_langchain_format) <= 1:
191
 
192
  guide_kwargs = copy.deepcopy(GUIDE_KWARGS)
193
- guide_kwargs["api_key"] = os.getenv("FW_TOKEN") # expert model api key
194
  guide_kwargs["classifier_kwargs"]["api_key"] = os.getenv("HF_TOKEN") # classifier api key
195
 
196
  guide_config = RecursiveBalancingGuideConfig(**guide_kwargs)
 
26
  }
27
 
28
  GUIDE_KWARGS = {
29
+ "expert_model": "HuggingFaceH4/zephyr-7b-beta",
30
  # "meta-llama/Meta-Llama-3.1-70B-Instruct",
31
+ "inference_server_url": "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta",
 
 
 
32
  # "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-70B-Instruct",
33
+ "llm_backend": "HFChat",
 
34
  "classifier_kwargs": {
35
  "model_id": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
36
+ "inference_server_url": "https://api-inference.huggingface.co/models/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
37
+ "batch_size": 8,
 
38
  },
39
  }
40
 
 
185
  if len(history_langchain_format) <= 1:
186
 
187
  guide_kwargs = copy.deepcopy(GUIDE_KWARGS)
188
+ guide_kwargs["api_key"] = os.getenv("HF_TOKEN") # expert model api key
189
  guide_kwargs["classifier_kwargs"]["api_key"] = os.getenv("HF_TOKEN") # classifier api key
190
 
191
  guide_config = RecursiveBalancingGuideConfig(**guide_kwargs)