K00B404 commited on
Commit
2b0c812
1 Parent(s): cc0dc31

Update app.py

Browse files

added config builder

Files changed (1) hide show
  1. app.py +67 -0
app.py CHANGED
@@ -110,6 +110,72 @@ examples = [[str(f)] for f in pathlib.Path("examples").glob("*.yaml")]
110
  # when user do not provide a token.
111
  COMMUNITY_HF_TOKEN = os.getenv("COMMUNITY_HF_TOKEN")
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
115
  runner = LogsViewRunner()
@@ -190,6 +256,7 @@ with gr.Blocks() as demo:
190
  gr.Markdown(MARKDOWN_DESCRIPTION)
191
 
192
  with gr.Row():
 
193
  filename = gr.Textbox(visible=False, label="filename")
194
  config = gr.Code(language="yaml", lines=10, label="config.yaml")
195
  with gr.Column():
 
110
  # when user do not provide a token.
111
  COMMUNITY_HF_TOKEN = os.getenv("COMMUNITY_HF_TOKEN")
112
 
113
+ # config builder
114
+ import yaml
115
+
116
+ def generate_config(base_model, models, layer_range, merge_method):
117
+ slices = []
118
+ for model in models:
119
+ slice_config = {
120
+ "sources": [
121
+ {
122
+ "model": model,
123
+ "layer_range": layer_range
124
+ }
125
+ ]
126
+ }
127
+ slices.append(slice_config)
128
+
129
+ config = {
130
+ "slices": slices,
131
+ "merge_method": merge_method,
132
+ "base_model": base_model,
133
+ "parameters": {
134
+ "t": [
135
+ {
136
+ "filter": "self_attn",
137
+ "value": [0, 0.5, 0.3, 0.7, 1]
138
+ },
139
+ {
140
+ "filter": "mlp",
141
+ "value": [1, 0.5, 0.7, 0.3, 0]
142
+ },
143
+ {
144
+ "value": 0.5
145
+ }
146
+ ]
147
+ },
148
+ "dtype": "bfloat16"
149
+ }
150
+
151
+ return yaml.dump(config)
152
+
153
+
154
+ # Add these imports
155
+ from functools import partial
156
+ from itertools import chain
157
+
158
+ # Configure dropdown options
159
+ BASE_MODELS = ["bert-base-uncased", "distilbert-base-uncased", ...] # Add other base models here
160
+ MERGE_METHODS = ["linear", "slerp", ...] # Add other merge methods here
161
+ LAYER_RANGE = range(32)
162
+
163
+ # Create input objects
164
+ input_base_model = gr.Dropdown(label='Base Model', choices=BASE_MODELS)
165
+ input_models = gr.Multiselect(label='Models', choices=BASE_MODELS)
166
+ input_layer_range = gr.NumberSlider(minimum=0, maximum=32, step=1, label='Layer Range')
167
+ input_merge_method = gr.Dropdown(label='Merge Method', choices=MERGE_METHODS)
168
+
169
+ # Wrap generate_config in a partial function to fix the signature
170
+ partial_generate_config = partial(generate_config, base_model=input_base_model, merge_method=input_merge_method)
171
+
172
+ # Generate config block
173
+ gen_config_block = gr.Block()
174
+ with gen_config_block:
175
+ generated_config = gr.outputs.Textbox(label='Generated Config', interactive=False)
176
+ btn_generate_config = gr.Button('Generate Config', variant='secondary')
177
+ btn_generate_config.click(fn=partial_generate_config, inputs=[input_base_model, input_models, input_layer_range], outputs=[generated_config])
178
+
179
 
180
  def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
181
  runner = LogsViewRunner()
 
256
  gr.Markdown(MARKDOWN_DESCRIPTION)
257
 
258
  with gr.Row():
259
+ gen_config_block
260
  filename = gr.Textbox(visible=False, label="filename")
261
  config = gr.Code(language="yaml", lines=10, label="config.yaml")
262
  with gr.Column():