# kohya_ss/kohya_gui/class_source_model.py
# Uploaded by ABCCCYYY with huggingface_hub ("Upload folder using huggingface_hub"), commit cff1674 (verified).
import gradio as gr
import os
from .common_gui import (
get_file_path,
get_folder_path,
set_pretrained_model_name_or_path_input,
scriptdir,
list_dirs,
list_files,
create_refresh_button,
)
from .class_gui_config import KohyaSSGUIConfig
# Emoji glyphs; folder_symbol and document_symbol are used as button labels
# further down in this module.
folder_symbol = "\N{OPEN FILE FOLDER}"  # 📂
refresh_symbol = "\N{ANTICLOCKWISE DOWNWARDS AND UPWARDS OPEN CIRCLE ARROWS}"  # 🔄
save_style_symbol = "\N{FLOPPY DISK}"  # 💾
document_symbol = "\N{PAGE FACING UP}"  # 📄
# Preset model identifiers, grouped by model family. The order is kept as-is
# because the list is used verbatim as the leading dropdown choices.
default_models = (
    # Stable Diffusion XL
    [
        "stabilityai/stable-diffusion-xl-base-1.0",
        "stabilityai/stable-diffusion-xl-refiner-1.0",
    ]
    # Stable Diffusion 2.x
    + [
        "stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned",
        "stabilityai/stable-diffusion-2-1-base",
        "stabilityai/stable-diffusion-2-base",
        "stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned",
        "stabilityai/stable-diffusion-2-1",
        "stabilityai/stable-diffusion-2",
    ]
    # Stable Diffusion 1.x
    + [
        "runwayml/stable-diffusion-v1-5",
        "CompVis/stable-diffusion-v1-4",
    ]
)
class SourceModel:
    """Build the "Model" accordion and expose its Gradio components.

    Instantiating the class creates the components as a side effect and
    stores the ones other code needs (dropdowns, textboxes, checkboxes,
    radios and picker buttons) as instance attributes.
    """

    def __init__(
        self,
        save_model_as_choices=None,
        save_precision_choices=None,
        headless=False,
        finetuning=False,
        config: KohyaSSGUIConfig = None,
    ):
        """Create the components and wire up their events.

        Args:
            save_model_as_choices: Options for the "Save trained model as"
                radio. ``None`` selects the built-in list.
            save_precision_choices: Options for the "Save precision" radio.
                ``None`` selects the built-in list.
            headless: When true, the file/folder picker buttons are created
                hidden.
            finetuning: Selects the label of the image folder dropdown; also
                stored on the instance.
            config: Object with a ``get(key, default)`` method supplying
                initial values. ``None`` means "no configuration", so every
                default below applies.
        """
        # Defaults are built per call: list/dict literals in the signature
        # would be shared (and mutable) across every instance.
        if save_model_as_choices is None:
            save_model_as_choices = [
                "same as source model",
                "ckpt",
                "diffusers",
                "diffusers_safetensors",
                "safetensors",
            ]
        if save_precision_choices is None:
            save_precision_choices = [
                "float",
                "fp16",
                "bf16",
            ]

        self.headless = headless
        self.save_model_as_choices = save_model_as_choices
        self.finetuning = finetuning
        self.config = {} if config is None else config

        # Set default directories if not provided
        self.current_models_dir = self.config.get(
            "model.models_dir", os.path.join(scriptdir, "models")
        )
        self.current_train_data_dir = self.config.get(
            "model.train_data_dir", os.path.join(scriptdir, "data")
        )
        self.current_dataset_config_dir = self.config.get(
            "model.dataset_config", os.path.join(scriptdir, "dataset_config")
        )

        model_checkpoints = list(
            list_files(
                self.current_models_dir, exts=[".ckpt", ".safetensors"], all=True
            )
        )

        def list_models(path):
            """Return the preset models plus the checkpoints listed for *path*."""
            # Remember the directory (the parent when *path* is not a
            # directory) so the refresh button lists the same place again.
            self.current_models_dir = (
                path if os.path.isdir(path) else os.path.dirname(path)
            )
            return default_models + list(
                list_files(path, exts=[".ckpt", ".safetensors"], all=True)
            )

        def list_train_data_dirs(path):
            """Return the directories listed for *path* ("." when empty)."""
            self.current_train_data_dir = path if path != "" else "."
            return list(list_dirs(self.current_train_data_dir))

        def list_dataset_config_dirs(path: str) -> list:
            """
            List the .toml entries that ``list_files`` reports for *path*.

            Parameters:
            - path (str): The path to list from; "" is treated as ".".

            Returns:
            - list: The entries used to populate the dropdown choices.
            """
            # NOTE(review): unlike the two helpers above, this one does not
            # update self.current_dataset_config_dir — confirm that is intended.
            current_dataset_config_dir = path if path != "" else "."
            # Lists all .toml files for the given path, used for populating
            # dropdown choices.
            return list(
                list_files(current_dataset_config_dir, exts=[".toml"], all=True)
            )

        with gr.Accordion("Model", open=True):
            with gr.Column(), gr.Group():
                # Hidden textboxes that feed constant arguments to get_file_path.
                model_ext = gr.Textbox(value="*.safetensors *.ckpt", visible=False)
                model_ext_name = gr.Textbox(value="Model types", visible=False)

                # Define the input elements
                with gr.Row():
                    with gr.Column(), gr.Row():
                        self.model_list = gr.Textbox(visible=False, value="")
                        # NOTE(review): the initial value is read from the same
                        # "model.models_dir" key that seeds current_models_dir
                        # above — confirm a dedicated key is not expected here.
                        self.pretrained_model_name_or_path = gr.Dropdown(
                            label="Pretrained model name or path",
                            choices=default_models + model_checkpoints,
                            value=self.config.get(
                                "model.models_dir", "runwayml/stable-diffusion-v1-5"
                            ),
                            allow_custom_value=True,
                            visible=True,
                            min_width=100,
                        )
                        create_refresh_button(
                            self.pretrained_model_name_or_path,
                            lambda: None,
                            lambda: {"choices": list_models(self.current_models_dir)},
                            "open_folder_small",
                        )

                        # File picker for a single checkpoint file.
                        self.pretrained_model_name_or_path_file = gr.Button(
                            document_symbol,
                            elem_id="open_folder_small",
                            elem_classes=["tool"],
                            visible=(not headless),
                        )
                        self.pretrained_model_name_or_path_file.click(
                            get_file_path,
                            inputs=[
                                self.pretrained_model_name_or_path,
                                model_ext,
                                model_ext_name,
                            ],
                            outputs=self.pretrained_model_name_or_path,
                            show_progress=False,
                        )

                        # Folder picker for a model directory.
                        self.pretrained_model_name_or_path_folder = gr.Button(
                            folder_symbol,
                            elem_id="open_folder_small",
                            elem_classes=["tool"],
                            visible=(not headless),
                        )
                        self.pretrained_model_name_or_path_folder.click(
                            get_folder_path,
                            inputs=self.pretrained_model_name_or_path,
                            outputs=self.pretrained_model_name_or_path,
                            show_progress=False,
                        )

                    with gr.Column(), gr.Row():
                        self.output_name = gr.Textbox(
                            label="Trained Model output name",
                            placeholder="(Name of the model to output)",
                            value=self.config.get("model.output_name", "last"),
                            interactive=True,
                        )

                with gr.Row():
                    with gr.Column(), gr.Row():
                        self.train_data_dir = gr.Dropdown(
                            label=(
                                "Image folder (containing training images subfolders)"
                                if not finetuning
                                else "Image folder (containing training images)"
                            ),
                            choices=[""]
                            + list_train_data_dirs(self.current_train_data_dir),
                            value=self.config.get("model.train_data_dir", ""),
                            interactive=True,
                            allow_custom_value=True,
                        )
                        create_refresh_button(
                            self.train_data_dir,
                            lambda: None,
                            lambda: {
                                "choices": [""]
                                + list_train_data_dirs(self.current_train_data_dir)
                            },
                            "open_folder_small",
                        )
                        # Use the shared folder_symbol constant, as the model
                        # folder button above does, instead of a raw literal.
                        self.train_data_dir_folder = gr.Button(
                            folder_symbol,
                            elem_id="open_folder_small",
                            elem_classes=["tool"],
                            visible=(not self.headless),
                        )
                        self.train_data_dir_folder.click(
                            get_folder_path,
                            outputs=self.train_data_dir,
                            show_progress=False,
                        )

                    with gr.Column(), gr.Row():
                        # Toml directory dropdown
                        self.dataset_config = gr.Dropdown(
                            label="Dataset config file (Optional. Select the toml configuration file to use for the dataset)",
                            choices=[self.config.get("model.dataset_config", "")]
                            + list_dataset_config_dirs(self.current_dataset_config_dir),
                            value=self.config.get("model.dataset_config", ""),
                            interactive=True,
                            allow_custom_value=True,
                        )
                        # Refresh button for dataset_config directory
                        create_refresh_button(
                            self.dataset_config,
                            lambda: None,
                            lambda: {
                                "choices": [""]
                                + list_dataset_config_dirs(
                                    self.current_dataset_config_dir
                                )
                            },
                            "open_folder_small",
                        )
                        # Toml directory button
                        self.dataset_config_folder = gr.Button(
                            document_symbol,
                            elem_id="open_folder_small",
                            elem_classes=["tool"],
                            visible=(not self.headless),
                        )
                        # Hidden textboxes carrying the constant file-dialog
                        # arguments, mirroring model_ext / model_ext_name.
                        dataset_config_ext = gr.Textbox(value="*.toml", visible=False)
                        dataset_config_ext_name = gr.Textbox(
                            value="Dataset config types", visible=False
                        )
                        # Toml directory button click event
                        self.dataset_config_folder.click(
                            get_file_path,
                            inputs=[
                                self.dataset_config,
                                dataset_config_ext,
                                dataset_config_ext_name,
                            ],
                            outputs=self.dataset_config,
                            show_progress=False,
                        )
                        # Change event for dataset_config directory dropdown
                        self.dataset_config.change(
                            fn=lambda path: gr.Dropdown(
                                choices=[""] + list_dataset_config_dirs(path)
                            ),
                            inputs=self.dataset_config,
                            outputs=self.dataset_config,
                            show_progress=False,
                        )

                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            # Created hidden; they are outputs of the
                            # pretrained-model change event wired below.
                            self.v2 = gr.Checkbox(
                                label="v2", value=False, visible=False, min_width=60
                            )
                            self.v_parameterization = gr.Checkbox(
                                label="v_parameterization",
                                value=False,
                                visible=False,
                                min_width=130,
                            )
                            self.sdxl_checkbox = gr.Checkbox(
                                label="SDXL",
                                value=False,
                                visible=False,
                                min_width=60,
                            )
                    with gr.Column():
                        gr.Group(visible=False)

                with gr.Row():
                    self.training_comment = gr.Textbox(
                        label="Training comment",
                        placeholder="(Optional) Add training comment to be included in metadata",
                        interactive=True,
                        value=self.config.get("model.training_comment", ""),
                    )

                with gr.Row():
                    self.save_model_as = gr.Radio(
                        save_model_as_choices,
                        label="Save trained model as",
                        value=self.config.get("model.save_model_as", "safetensors"),
                    )
                    self.save_precision = gr.Radio(
                        save_precision_choices,
                        label="Save precision",
                        value=self.config.get("model.save_precision", "fp16"),
                    )

                # Selecting a model updates the dropdown and the three
                # model-type checkboxes.
                self.pretrained_model_name_or_path.change(
                    fn=lambda path: set_pretrained_model_name_or_path_input(
                        path, refresh_method=list_models
                    ),
                    inputs=[
                        self.pretrained_model_name_or_path,
                    ],
                    outputs=[
                        self.pretrained_model_name_or_path,
                        self.v2,
                        self.v_parameterization,
                        self.sdxl_checkbox,
                    ],
                    show_progress=False,
                )

                # Re-list the sub-directories whenever the image folder changes.
                self.train_data_dir.change(
                    fn=lambda path: gr.Dropdown(
                        choices=[""] + list_train_data_dirs(path)
                    ),
                    inputs=self.train_data_dir,
                    outputs=self.train_data_dir,
                    show_progress=False,
                )