|
import os |
|
import sys |
|
|
|
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../"))) |
|
from pytorch_lightning import seed_everything |
|
from scripts.demo.streamlit_helpers import * |
|
from scripts.demo.sv3d_helpers import * |
|
|
|
SAVE_PATH = "outputs/demo/vid/"

# Conditioner keys whose *unconditional* embeddings are zeroed for every model.
_ZERO_EMBED_KEYS = ["cond_frames", "cond_frames_without_noise"]


def _make_spec(
    num_frames,
    width,
    config,
    ckpt,
    guider,
    cfg,
    num_steps,
    min_cfg=None,
    decoding_t=None,
):
    """Build one VERSION2SPECS entry.

    Shared defaults (H=576, C=4, f=8, EDM sigma schedule) live here so each
    model version only states what differs. Fresh ``options`` dicts and
    embedding-key lists are created per call, so later in-place mutation of a
    selected version's options never leaks into another version.
    """
    options = {
        "discretization": 1,
        "cfg": cfg,
    }
    if min_cfg is not None:
        # Kept directly after "cfg" to match the historical key order.
        options["min_cfg"] = min_cfg
    options.update(
        {
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": guider,
            "force_uc_zero_embeddings": list(_ZERO_EMBED_KEYS),
            "num_steps": num_steps,
        }
    )
    if decoding_t is not None:
        options["decoding_t"] = decoding_t
    return {
        "T": num_frames,
        "H": 576,
        "W": width,
        "C": 4,
        "f": 8,
        "config": config,
        "ckpt": ckpt,
        "options": options,
    }


# Per-version model specs: frame count T, latent geometry (H, W, C, f),
# config/checkpoint paths, and default sampler options.
VERSION2SPECS = {
    "svd": _make_spec(
        14,
        1024,
        "configs/inference/svd.yaml",
        "checkpoints/svd.safetensors",
        guider=2,
        cfg=2.5,
        num_steps=25,
    ),
    "svd_image_decoder": _make_spec(
        14,
        1024,
        "configs/inference/svd_image_decoder.yaml",
        "checkpoints/svd_image_decoder.safetensors",
        guider=2,
        cfg=2.5,
        num_steps=25,
    ),
    "svd_xt": _make_spec(
        25,
        1024,
        "configs/inference/svd.yaml",
        "checkpoints/svd_xt.safetensors",
        guider=2,
        cfg=3.0,
        num_steps=30,
        min_cfg=1.5,
        decoding_t=14,
    ),
    "svd_xt_image_decoder": _make_spec(
        25,
        1024,
        "configs/inference/svd_image_decoder.yaml",
        "checkpoints/svd_xt_image_decoder.safetensors",
        guider=2,
        cfg=3.0,
        num_steps=30,
        min_cfg=1.5,
        decoding_t=14,
    ),
    "sv3d_u": _make_spec(
        21,
        576,
        "configs/inference/sv3d_u.yaml",
        "checkpoints/sv3d_u.safetensors",
        guider=3,
        cfg=2.5,
        num_steps=50,
        decoding_t=14,
    ),
    "sv3d_p": _make_spec(
        21,
        576,
        "configs/inference/sv3d_p.yaml",
        "checkpoints/sv3d_p.safetensors",
        guider=3,
        cfg=2.5,
        num_steps=50,
        decoding_t=14,
    ),
}
|
|
|
|
|
if __name__ == "__main__":
    # Streamlit demo entry point: choose a model version, optionally load it,
    # collect conditioning inputs from the UI, then sample and save a video.
    st.title("Stable Video Diffusion / SV3D")
    version = st.selectbox(
        "Model Version",
        [k for k in VERSION2SPECS.keys()],
        0,
    )
    version_dict = VERSION2SPECS[version]
    # The checkbox gates model loading; until it is ticked, everything below
    # the dimension widgets is skipped (mode == "skip").
    if st.checkbox("Load Model"):
        mode = "img2vid"
    else:
        mode = "skip"

    # Spatial/temporal dimensions default to the selected version's spec.
    H = st.sidebar.number_input(
        "H", value=version_dict["H"], min_value=64, max_value=2048
    )
    W = st.sidebar.number_input(
        "W", value=version_dict["W"], min_value=64, max_value=2048
    )
    T = st.sidebar.number_input(
        "T", value=version_dict["T"], min_value=0, max_value=128
    )
    C = version_dict["C"]  # latent channel count
    F = version_dict["f"]  # autoencoder spatial downsampling factor
    options = version_dict["options"]

    if mode != "skip":
        # Build (or fetch cached) model + safety filter for this version.
        state = init_st(version_dict, load_filter=True)
        if state["msg"]:
            st.info(state["msg"])
        model = state["model"]

        # Keys the conditioner expects; drives which UI options are shown.
        ukeys = set(
            get_unique_embedder_keys_from_conditioner(state["model"].conditioner)
        )

        value_dict = init_embedder_options(
            ukeys,
            {},
        )

        # NOTE(review): assumes init_embedder_options sets value_dict["fps"]
        # whenever the conditioner requests it — confirm; otherwise this
        # fallback of 10 fps applies (also used for mp4 saving below).
        if "fps" not in ukeys:
            value_dict["fps"] = 10

        value_dict["image_only_indicator"] = 0

        if mode == "img2vid":
            img = load_img_for_prediction(W, H)
            if "sv3d" in version:
                # SV3D conditioning uses an effectively noise-free frame.
                cond_aug = 1e-5
            else:
                cond_aug = st.number_input(
                    "Conditioning augmentation:", value=0.02, min_value=0.0
                )
            value_dict["cond_frames_without_noise"] = img
            # Noisy copy of the conditioning frame, scaled by cond_aug.
            value_dict["cond_frames"] = img + cond_aug * torch.randn_like(img)
            value_dict["cond_aug"] = cond_aug

        # Camera-trajectory conditioning (SV3D only): per-frame polar and
        # azimuth angles in radians.
        if "sv3d_p" in version:
            elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90)
            trajectory = st.selectbox(
                "Trajectory",
                ["same elevation", "dynamic"],
                0,
            )
            if trajectory == "same elevation":
                # Constant polar angle; azimuths sweep one full orbit,
                # excluding the 0-rad start point.
                value_dict["polars_rad"] = np.array([np.deg2rad(90 - elev_deg)] * T)
                value_dict["azimuths_rad"] = np.linspace(0, 2 * np.pi, T + 1)[1:]
            elif trajectory == "dynamic":
                # gen_dynamic_loop returns a varying-elevation camera loop.
                azim_rad, elev_rad = gen_dynamic_loop(length=21, elev_deg=elev_deg)
                value_dict["polars_rad"] = np.deg2rad(90) - elev_rad
                value_dict["azimuths_rad"] = azim_rad
        elif "sv3d_u" in version:
            elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90)
            value_dict["polars_rad"] = np.array([np.deg2rad(90 - elev_deg)] * T)
            value_dict["azimuths_rad"] = np.linspace(0, 2 * np.pi, T + 1)[1:]

        seed = st.sidebar.number_input(
            "seed", value=23, min_value=0, max_value=int(1e9)
        )
        seed_everything(seed)

        save_locally, save_path = init_save_locally(
            os.path.join(SAVE_PATH, version), init_value=True
        )

        if "sv3d" in version:
            # Render the requested camera orbit so the user can inspect it
            # before committing to a (slow) sampling run.
            plot_save_path = os.path.join(save_path, "plot_3D.png")
            plot_3D(
                azim=value_dict["azimuths_rad"],
                polar=value_dict["polars_rad"],
                save_path=plot_save_path,
                dynamic=("sv3d_p" in version),
            )
            st.image(
                plot_save_path,
                f"3D camera trajectory",
            )

        options["num_frames"] = T

        sampler, num_rows, num_cols = init_sampling(options=options)
        num_samples = num_rows * num_cols

        # Frames decoded per VAE pass; smaller values trade speed for VRAM.
        decoding_t = st.number_input(
            "Decode t frames at a time (set small if you are low on VRAM)",
            value=options.get("decoding_t", T),
            min_value=1,
            max_value=int(1e9),
        )

        if st.checkbox("Overwrite fps in mp4 generator", False):
            saving_fps = st.number_input(
                f"saving video at fps:", value=value_dict["fps"], min_value=1
            )
        else:
            saving_fps = value_dict["fps"]

        if st.button("Sample"):
            out = do_sample(
                model,
                sampler,
                value_dict,
                num_samples,
                H,
                W,
                C,
                F,
                T=T,
                batch2model_input=["num_video_frames", "image_only_indicator"],
                force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None),
                force_cond_zero_embeddings=options.get(
                    "force_cond_zero_embeddings", None
                ),
                return_latents=False,
                decoding_t=decoding_t,
            )

            # do_sample may return (samples, latents) or samples alone;
            # normalize to both names (latents unused here).
            if isinstance(out, (tuple, list)):
                samples, samples_z = out
            else:
                samples = out
                samples_z = None

            if save_locally:
                save_video_as_grid_and_mp4(samples, save_path, T, fps=saving_fps)