Spaces:
Runtime error
Runtime error
Update app.py
Browse filesadded config builder
app.py
CHANGED
@@ -110,6 +110,72 @@ examples = [[str(f)] for f in pathlib.Path("examples").glob("*.yaml")]
|
|
110 |
# when user do not provide a token.
|
111 |
COMMUNITY_HF_TOKEN = os.getenv("COMMUNITY_HF_TOKEN")
|
112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
|
115 |
runner = LogsViewRunner()
|
@@ -190,6 +256,7 @@ with gr.Blocks() as demo:
|
|
190 |
gr.Markdown(MARKDOWN_DESCRIPTION)
|
191 |
|
192 |
with gr.Row():
|
|
|
193 |
filename = gr.Textbox(visible=False, label="filename")
|
194 |
config = gr.Code(language="yaml", lines=10, label="config.yaml")
|
195 |
with gr.Column():
|
|
|
110 |
# when user do not provide a token.
|
111 |
COMMUNITY_HF_TOKEN = os.getenv("COMMUNITY_HF_TOKEN")
|
112 |
|
113 |
+
# config builder
|
114 |
+
import yaml
|
115 |
+
|
116 |
+
def generate_config(base_model, models, layer_range, merge_method):
|
117 |
+
slices = []
|
118 |
+
for model in models:
|
119 |
+
slice_config = {
|
120 |
+
"sources": [
|
121 |
+
{
|
122 |
+
"model": model,
|
123 |
+
"layer_range": layer_range
|
124 |
+
}
|
125 |
+
]
|
126 |
+
}
|
127 |
+
slices.append(slice_config)
|
128 |
+
|
129 |
+
config = {
|
130 |
+
"slices": slices,
|
131 |
+
"merge_method": merge_method,
|
132 |
+
"base_model": base_model,
|
133 |
+
"parameters": {
|
134 |
+
"t": [
|
135 |
+
{
|
136 |
+
"filter": "self_attn",
|
137 |
+
"value": [0, 0.5, 0.3, 0.7, 1]
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"filter": "mlp",
|
141 |
+
"value": [1, 0.5, 0.7, 0.3, 0]
|
142 |
+
},
|
143 |
+
{
|
144 |
+
"value": 0.5
|
145 |
+
}
|
146 |
+
]
|
147 |
+
},
|
148 |
+
"dtype": "bfloat16"
|
149 |
+
}
|
150 |
+
|
151 |
+
return yaml.dump(config)
|
152 |
+
|
153 |
+
|
154 |
+
# Add these imports
|
155 |
+
from functools import partial
|
156 |
+
from itertools import chain
|
157 |
+
|
158 |
+
# Configure dropdown options
|
159 |
+
BASE_MODELS = ["bert-base-uncased", "distilbert-base-uncased", ...] # Add other base models here
|
160 |
+
MERGE_METHODS = ["linear", "slerp", ...] # Add other merge methods here
|
161 |
+
LAYER_RANGE = range(32)
|
162 |
+
|
163 |
+
# Create input objects
|
164 |
+
input_base_model = gr.Dropdown(label='Base Model', choices=BASE_MODELS)
|
165 |
+
input_models = gr.Multiselect(label='Models', choices=BASE_MODELS)
|
166 |
+
input_layer_range = gr.NumberSlider(minimum=0, maximum=32, step=1, label='Layer Range')
|
167 |
+
input_merge_method = gr.Dropdown(label='Merge Method', choices=MERGE_METHODS)
|
168 |
+
|
169 |
+
# Wrap generate_config in a partial function to fix the signature
|
170 |
+
partial_generate_config = partial(generate_config, base_model=input_base_model, merge_method=input_merge_method)
|
171 |
+
|
172 |
+
# Generate config block
|
173 |
+
gen_config_block = gr.Block()
|
174 |
+
with gen_config_block:
|
175 |
+
generated_config = gr.outputs.Textbox(label='Generated Config', interactive=False)
|
176 |
+
btn_generate_config = gr.Button('Generate Config', variant='secondary')
|
177 |
+
btn_generate_config.click(fn=partial_generate_config, inputs=[input_base_model, input_models, input_layer_range], outputs=[generated_config])
|
178 |
+
|
179 |
|
180 |
def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
|
181 |
runner = LogsViewRunner()
|
|
|
256 |
gr.Markdown(MARKDOWN_DESCRIPTION)
|
257 |
|
258 |
with gr.Row():
|
259 |
+
gen_config_block
|
260 |
filename = gr.Textbox(visible=False, label="filename")
|
261 |
config = gr.Code(language="yaml", lines=10, label="config.yaml")
|
262 |
with gr.Column():
|