sgbaird commited on
Commit
f618644
·
1 Parent(s): b3dff97

pydantic validation of x1, x2, ... syntax

Browse files
Files changed (1) hide show
  1. app.py +75 -8
app.py CHANGED
@@ -3,6 +3,13 @@ import gradio as gr
3
  import pandas as pd
4
  from sklearn.preprocessing import MinMaxScaler
5
  from surrogate import CrabNetSurrogateModel, PARAM_BOUNDS
 
 
 
 
 
 
 
6
 
7
  model = CrabNetSurrogateModel()
8
 
@@ -34,15 +41,82 @@ example_parameterization = {
34
  "train_frac": 0.5,
35
  }
36
 
37
- # Define the output parameters
38
  example_results = model.surrogate_evaluate([example_parameterization])
39
  example_result = example_results[0]
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  def evaluate(*args):
43
  # Create a DataFrame with the parameter names and scaled values
44
  params_df = pd.DataFrame([args], columns=[param["name"] for param in PARAM_BOUNDS])
45
 
 
 
 
46
  # Reverse the scaling for each parameter and reverse the renaming for choice parameters
47
  for param_info in PARAM_BOUNDS:
48
  key = param_info["name"]
@@ -65,13 +139,6 @@ def evaluate(*args):
65
  return results_list
66
 
67
 
68
- scalers = {
69
- param_info["name"]: MinMaxScaler()
70
- for param_info in PARAM_BOUNDS
71
- if param_info["type"] == "range"
72
- }
73
-
74
-
75
  def get_interface(param_info, numeric_index, choice_index):
76
  key = param_info["name"]
77
  default_value = example_parameterization[key]
 
3
  import pandas as pd
4
  from sklearn.preprocessing import MinMaxScaler
5
  from surrogate import CrabNetSurrogateModel, PARAM_BOUNDS
6
+ from pydantic import (
7
+ BaseModel,
8
+ ValidationError,
9
+ ValidationInfo,
10
+ field_validator,
11
+ model_validator,
12
+ )
13
 
14
  model = CrabNetSurrogateModel()
15
 
 
41
  "train_frac": 0.5,
42
  }
43
 
 
44
  example_results = model.surrogate_evaluate([example_parameterization])
45
  example_result = example_results[0]
46
 
47
+ scalers = {
48
+ param_info["name"]: MinMaxScaler()
49
+ for param_info in PARAM_BOUNDS
50
+ if param_info["type"] == "range"
51
+ }
52
+
53
+
54
+ class BlindedParameterization(BaseModel):
55
+ x1: float # int
56
+ x2: float
57
+ x3: float # int
58
+ x4: float # int
59
+ x5: float
60
+ x6: float
61
+ x7: float # int
62
+ x8: float
63
+ x9: float
64
+ x10: float # int
65
+ x11: float # int
66
+ x12: float
67
+ x13: float # int
68
+ x14: float # int
69
+ x15: float
70
+ x16: float # int
71
+ x17: float # int
72
+ x18: float # int
73
+ x19: float
74
+ x20: float
75
+ c1: bool
76
+ c2: str
77
+ c3: str
78
+ f1: float
79
+
80
+ @field_validator("*")
81
+ def check_bounds(cls, v: int, info: ValidationInfo) -> int:
82
+ param = next(
83
+ (item for item in PARAM_BOUNDS if item["name"] == info.field_name),
84
+ None,
85
+ )
86
+ if param is None:
87
+ return v
88
+
89
+ if param["type"] == "range":
90
+ min_val, max_val = param["bounds"]
91
+ if not min_val <= v <= max_val:
92
+ raise ValueError(
93
+ f"{info.field_name} must be between {min_val} and {max_val}"
94
+ )
95
+ elif param["type"] == "choice":
96
+ if v not in param["values"]:
97
+ raise ValueError(f"{info.field_name} must be one of {param['values']}")
98
+
99
+ return v
100
+
101
+ @model_validator(mode="after")
102
+ def check_constraints(self) -> "BlindedParameterization":
103
+ if self.x19 > self.x20:
104
+ raise ValueError(
105
+ f"Received x19={self.x19} which should be less than x20={self.x20}"
106
+ )
107
+ if self.x6 + self.x15 > 1.0:
108
+ raise ValueError(
109
+ f"Received x6={self.x6} and x15={self.x15} which should sum to less than or equal to 1.0" # noqa: E501
110
+ )
111
+
112
 
113
  def evaluate(*args):
114
  # Create a DataFrame with the parameter names and scaled values
115
  params_df = pd.DataFrame([args], columns=[param["name"] for param in PARAM_BOUNDS])
116
 
117
+ # error checking
118
+ BlindedParameterization(**params_df.to_dict("records")[0])
119
+
120
  # Reverse the scaling for each parameter and reverse the renaming for choice parameters
121
  for param_info in PARAM_BOUNDS:
122
  key = param_info["name"]
 
139
  return results_list
140
 
141
 
 
 
 
 
 
 
 
142
  def get_interface(param_info, numeric_index, choice_index):
143
  key = param_info["name"]
144
  default_value = example_parameterization[key]