Gregor Betz commited on
Commit
9225b05
1 Parent(s): 5106691

config bugfix

Browse files
Files changed (3) hide show
  1. app.py +14 -10
  2. backend/config.py +3 -24
  3. config.yaml +3 -1
app.py CHANGED
@@ -16,15 +16,6 @@ from backend.svg_processing import postprocess_svg
16
 
17
  logging.basicConfig(level=logging.DEBUG)
18
 
19
- with open("config.yaml") as stream:
20
- try:
21
- DEMO_CONFIG = yaml.safe_load(stream)
22
- logging.debug(f"Config: {DEMO_CONFIG}")
23
- except yaml.YAMLError as exc:
24
- logging.error(f"Error loading config: {exc}")
25
- raise exc
26
-
27
-
28
 
29
  EXAMPLES = [
30
  ("We're a nature-loving family with three kids, have some money left, and no plans "
@@ -94,7 +85,20 @@ CHATBOT_INSTRUCTIONS = (
94
  )
95
 
96
  # config
97
- client_kwargs, guide_kwargs = process_config(DEMO_CONFIG)
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  logging.info(f"Reasoning guide expert model is {guide_kwargs['expert_model']}.")
99
 
100
 
 
16
 
17
  logging.basicConfig(level=logging.DEBUG)
18
 
 
 
 
 
 
 
 
 
 
19
 
20
  EXAMPLES = [
21
  ("We're a nature-loving family with three kids, have some money left, and no plans "
 
85
  )
86
 
87
  # config
88
+ with open("config.yaml") as stream:
89
+ try:
90
+ demo_config = yaml.safe_load(stream)
91
+ logging.debug(f"Config: {demo_config}")
92
+ except yaml.YAMLError as exc:
93
+ logging.error(f"Error loading config: {exc}")
94
+ gr.Error("Error loading config: {exc}")
95
+
96
+ try:
97
+ client_kwargs, guide_kwargs = process_config(demo_config)
98
+ except Exception as exc:
99
+ logging.error(f"Error processing config: {exc}")
100
+ gr.Error(f"Error processing config: {exc}")
101
+
102
  logging.info(f"Reasoning guide expert model is {guide_kwargs['expert_model']}.")
103
 
104
 
backend/config.py CHANGED
@@ -1,26 +1,5 @@
1
  import os
2
 
3
- # Default client
4
- INFERENCE_SERVER_URL = "https://api-inference.huggingface.co/models/{model_id}"
5
- MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
6
- CLIENT_MODEL_KWARGS = {
7
- "max_tokens": 800,
8
- "temperature": 0.6,
9
- }
10
-
11
- GUIDE_KWARGS = {
12
- "expert_model": "HuggingFaceH4/zephyr-7b-beta",
13
- # "meta-llama/Meta-Llama-3.1-70B-Instruct",
14
- "inference_server_url": "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta",
15
- # "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-70B-Instruct",
16
- "llm_backend": "HFChat",
17
- "classifier_kwargs": {
18
- "model_id": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
19
- "inference_server_url": "https://api-inference.huggingface.co/models/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
20
- "batch_size": 8,
21
- },
22
- }
23
-
24
 
25
  def process_config(config):
26
  if "HF_TOKEN" not in os.environ:
@@ -37,8 +16,8 @@ def process_config(config):
37
  raise ValueError("config.yaml is missing client url.")
38
  client_kwargs["api_key"] = os.getenv("HF_TOKEN")
39
  client_kwargs["llm_backend"] = "HFChat"
40
- client_kwargs["temperature"] = CLIENT_MODEL_KWARGS["temperature"]
41
- client_kwargs["max_tokens"] = CLIENT_MODEL_KWARGS["max_tokens"]
42
  else:
43
  raise ValueError("config.yaml is missing client_llm settings.")
44
 
@@ -67,7 +46,7 @@ def process_config(config):
67
  else:
68
  raise ValueError("config.yaml is missing classifier url.")
69
  if "batch_size" in config["classifier_llm"]:
70
- guide_kwargs["classifier_kwargs"]["batch_size"] = config["classifier_llm"]["batch_size"]
71
  else:
72
  raise ValueError("config.yaml is missing classifier batch_size.")
73
  guide_kwargs["classifier_kwargs"]["api_key"] = os.getenv("HF_TOKEN") # classifier api key
 
1
  import os
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  def process_config(config):
5
  if "HF_TOKEN" not in os.environ:
 
16
  raise ValueError("config.yaml is missing client url.")
17
  client_kwargs["api_key"] = os.getenv("HF_TOKEN")
18
  client_kwargs["llm_backend"] = "HFChat"
19
+ client_kwargs["temperature"] = config["client_llm"].get("temperature",.6)
20
+ client_kwargs["max_tokens"] = config["client_llm"].get("max_tokens",800)
21
  else:
22
  raise ValueError("config.yaml is missing client_llm settings.")
23
 
 
46
  else:
47
  raise ValueError("config.yaml is missing classifier url.")
48
  if "batch_size" in config["classifier_llm"]:
49
+ guide_kwargs["classifier_kwargs"]["batch_size"] = int(config["classifier_llm"]["batch_size"])
50
  else:
51
  raise ValueError("config.yaml is missing classifier batch_size.")
52
  guide_kwargs["classifier_kwargs"]["api_key"] = os.getenv("HF_TOKEN") # classifier api key
config.yaml CHANGED
@@ -1,10 +1,12 @@
1
  client_llm:
2
  url: "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
3
  model_id: "HuggingFaceH4/zephyr-7b-beta"
 
 
4
  expert_llm:
5
  url: "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
6
  model_id: "HuggingFaceH4/zephyr-7b-beta"
7
  classifier_llm:
8
  model_id: "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
9
  url: "https://api-inference.huggingface.co/models/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
10
- batch_size: 8,
 
1
  client_llm:
2
  url: "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
3
  model_id: "HuggingFaceH4/zephyr-7b-beta"
4
+ max_tokens: 800
5
+ temperature: 0.6
6
  expert_llm:
7
  url: "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
8
  model_id: "HuggingFaceH4/zephyr-7b-beta"
9
  classifier_llm:
10
  model_id: "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
11
  url: "https://api-inference.huggingface.co/models/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
12
+ batch_size: 8