|
from huggingface_hub import scan_cache_dir
|
|
|
|
|
|
def get_sampler_names():
    """Get sampler name list.

    A new list object is built on every call, so callers may mutate the
    result without affecting later calls.

    Returns:
        list: sampler name list
    """
    sampler_names = [
        "DDIM",
        "Euler",
        "Euler a",
        "DPM2 Karras",
        "DPM2 a Karras",
    ]
    return sampler_names
|
|
|
|
|
|
def get_sam_model_ids():
    """Get SAM model ids list.

    The entries are checkpoint file names (``.pt`` / ``.pth``). A new list
    object is built on every call, so callers may mutate the result without
    affecting later calls.

    Returns:
        list: SAM model ids list
    """
    sam_model_ids = [
        "sam2_hiera_large.pt",
        "sam2_hiera_base_plus.pt",
        "sam2_hiera_small.pt",
        "sam2_hiera_tiny.pt",
        "sam_vit_h_4b8939.pth",
        "sam_vit_l_0b3195.pth",
        "sam_vit_b_01ec64.pth",
        "sam_hq_vit_h.pth",
        "sam_hq_vit_l.pth",
        "sam_hq_vit_b.pth",
        "FastSAM-x.pt",
        "FastSAM-s.pt",
        "mobile_sam.pt",
    ]
    return sam_model_ids
|
|
|
|
|
|
# Inpainting repo ids found in the local Hugging Face cache. Stays None until
# the first successful scan; afterwards it holds the sorted list so later
# calls to get_inp_model_ids() can skip rescanning the cache.
inp_list_from_cache = None


def get_inp_model_ids():
    """Get inpainting model ids list.

    The built-in ids always come first. They are followed by any model repos
    in the local Hugging Face cache whose repo id contains "inpaint"
    (case-insensitive) and is not already one of the built-in ids, sorted in
    reverse order by the part of the repo id after the last "/".

    The scan result is stored in the module-level ``inp_list_from_cache`` so
    the cache is scanned at most once per successful scan. If the scan fails,
    nothing is stored and only the built-in ids are returned; the next call
    tries the scan again.

    Returns:
        list: model ids list
    """
    global inp_list_from_cache

    model_ids = [
        "stabilityai/stable-diffusion-2-inpainting",
        "Uminosachi/dreamshaper_8Inpainting",
        "Uminosachi/deliberate_v3-inpainting",
        "Uminosachi/realisticVisionV51_v51VAE-inpainting",
        "Uminosachi/revAnimated_v121Inp-inpainting",
        "runwayml/stable-diffusion-inpainting",
    ]

    # isinstance() already rejects None, so no separate None check is needed.
    if isinstance(inp_list_from_cache, list):
        model_ids.extend(inp_list_from_cache)
        return model_ids

    try:
        hf_cache_info = scan_cache_dir()
        inpaint_repos = [
            repo.repo_id
            for repo in hf_cache_info.repos
            if repo.repo_type == "model"
            and "inpaint" in repo.repo_id.lower()
            and repo.repo_id not in model_ids
        ]
        inp_list_from_cache = sorted(inpaint_repos, reverse=True, key=lambda x: x.split("/")[-1])
    except Exception:
        # Deliberate best effort: listing cached repos is optional, so any
        # failure while scanning falls back to the built-in ids only.
        return model_ids

    model_ids.extend(inp_list_from_cache)
    return model_ids
|
|
|
|
|
|
def get_cleaner_model_ids():
    """Get cleaner model ids list.

    A new list object is built on every call, so callers may mutate the
    result without affecting later calls.

    Returns:
        list: model ids list
    """
    model_ids = [
        "lama",
        "ldm",
        "zits",
        "mat",
        "fcf",
        "manga",
    ]
    return model_ids
|
|
|
|
|
|
def get_padding_mode_names():
    """Get padding mode name list.

    A new list object is built on every call, so callers may mutate the
    result without affecting later calls.

    Returns:
        list: padding mode name list
    """
    padding_mode_names = [
        "constant",
        "edge",
        "reflect",
        "mean",
        "median",
        "maximum",
        "minimum",
    ]
    return padding_mode_names
|
|
|