import os
import sys

# Make the repository root importable when this script is run directly
# (streamlit run scripts/demo/video_sampling.py).
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))

from pytorch_lightning import seed_everything

# NOTE: the wildcard imports intentionally bring `st` (streamlit), `torch`,
# `np`, and all demo helpers (init_st, do_sample, plot_3D, gen_dynamic_loop,
# load_img_for_prediction, ...) into scope.
from scripts.demo.streamlit_helpers import *
from scripts.demo.sv3d_helpers import *

# Root output directory; a per-model-version subfolder is appended below.
SAVE_PATH = "outputs/demo/vid/"

# Per-model specification:
#   T        - number of video frames sampled per clip
#   H, W     - default output resolution
#   C, f     - latent channels and VAE downsampling factor
#   config   - model config (OmegaConf yaml)
#   ckpt     - checkpoint path
#   options  - default sampler/guidance settings shown in the UI
VERSION2SPECS = {
    "svd": {
        "T": 14,
        "H": 576,
        "W": 1024,
        "C": 4,
        "f": 8,
        "config": "configs/inference/svd.yaml",
        "ckpt": "checkpoints/svd.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 2.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 2,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 25,
        },
    },
    "svd_image_decoder": {
        "T": 14,
        "H": 576,
        "W": 1024,
        "C": 4,
        "f": 8,
        "config": "configs/inference/svd_image_decoder.yaml",
        "ckpt": "checkpoints/svd_image_decoder.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 2.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 2,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 25,
        },
    },
    "svd_xt": {
        "T": 25,
        "H": 576,
        "W": 1024,
        "C": 4,
        "f": 8,
        "config": "configs/inference/svd.yaml",
        "ckpt": "checkpoints/svd_xt.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 3.0,
            "min_cfg": 1.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 2,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 30,
            "decoding_t": 14,
        },
    },
    "svd_xt_image_decoder": {
        "T": 25,
        "H": 576,
        "W": 1024,
        "C": 4,
        "f": 8,
        "config": "configs/inference/svd_image_decoder.yaml",
        "ckpt": "checkpoints/svd_xt_image_decoder.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 3.0,
            "min_cfg": 1.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 2,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 30,
            "decoding_t": 14,
        },
    },
    "sv3d_u": {
        "T": 21,
        "H": 576,
        "W": 576,
        "C": 4,
        "f": 8,
        "config": "configs/inference/sv3d_u.yaml",
        "ckpt": "checkpoints/sv3d_u.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 2.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 3,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 50,
            "decoding_t": 14,
        },
    },
    "sv3d_p": {
        "T": 21,
        "H": 576,
        "W": 576,
        "C": 4,
        "f": 8,
        "config": "configs/inference/sv3d_p.yaml",
        "ckpt": "checkpoints/sv3d_p.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 2.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 3,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 50,
            "decoding_t": 14,
        },
    },
}


if __name__ == "__main__":
    st.title("Stable Video Diffusion / SV3D")
    version = st.selectbox(
        "Model Version",
        list(VERSION2SPECS),
        0,
    )
    version_dict = VERSION2SPECS[version]

    # The model is only loaded (and sampling UI shown) once the user opts in;
    # until then the app stays cheap to render.
    if st.checkbox("Load Model"):
        mode = "img2vid"
    else:
        mode = "skip"

    H = st.sidebar.number_input(
        "H", value=version_dict["H"], min_value=64, max_value=2048
    )
    W = st.sidebar.number_input(
        "W", value=version_dict["W"], min_value=64, max_value=2048
    )
    T = st.sidebar.number_input(
        "T", value=version_dict["T"], min_value=0, max_value=128
    )
    C = version_dict["C"]
    F = version_dict["f"]
    options = version_dict["options"]

    if mode != "skip":
        state = init_st(version_dict, load_filter=True)
        if state["msg"]:
            st.info(state["msg"])
        model = state["model"]

        ukeys = set(
            get_unique_embedder_keys_from_conditioner(state["model"].conditioner)
        )

        value_dict = init_embedder_options(
            ukeys,
            {},
        )

        # Default fps for models whose conditioner does not take an fps key.
        if "fps" not in ukeys:
            value_dict["fps"] = 10

        value_dict["image_only_indicator"] = 0

        if mode == "img2vid":
            img = load_img_for_prediction(W, H)
            if "sv3d" in version:
                # SV3D expects an essentially clean conditioning frame.
                cond_aug = 1e-5
            else:
                cond_aug = st.number_input(
                    "Conditioning augmentation:", value=0.02, min_value=0.0
                )
            value_dict["cond_frames_without_noise"] = img
            value_dict["cond_frames"] = img + cond_aug * torch.randn_like(img)
            value_dict["cond_aug"] = cond_aug

        # Camera-pose conditioning for the SV3D variants: polar/azimuth angles
        # in radians, one per frame.
        if "sv3d_p" in version:
            elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90)
            trajectory = st.selectbox(
                "Trajectory",
                ["same elevation", "dynamic"],
                0,
            )
            if trajectory == "same elevation":
                value_dict["polars_rad"] = np.array([np.deg2rad(90 - elev_deg)] * T)
                value_dict["azimuths_rad"] = np.linspace(0, 2 * np.pi, T + 1)[1:]
            elif trajectory == "dynamic":
                # Use the user-selected T (was hardcoded to 21) so the camera
                # loop length always matches the number of sampled frames.
                azim_rad, elev_rad = gen_dynamic_loop(length=T, elev_deg=elev_deg)
                value_dict["polars_rad"] = np.deg2rad(90) - elev_rad
                value_dict["azimuths_rad"] = azim_rad
        elif "sv3d_u" in version:
            elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90)
            value_dict["polars_rad"] = np.array([np.deg2rad(90 - elev_deg)] * T)
            value_dict["azimuths_rad"] = np.linspace(0, 2 * np.pi, T + 1)[1:]

        seed = st.sidebar.number_input(
            "seed", value=23, min_value=0, max_value=int(1e9)
        )
        seed_everything(seed)

        save_locally, save_path = init_save_locally(
            os.path.join(SAVE_PATH, version), init_value=True
        )

        # Visualize the 3D camera trajectory for the SV3D variants.
        if "sv3d" in version:
            plot_save_path = os.path.join(save_path, "plot_3D.png")
            plot_3D(
                azim=value_dict["azimuths_rad"],
                polar=value_dict["polars_rad"],
                save_path=plot_save_path,
                dynamic=("sv3d_p" in version),
            )
            st.image(
                plot_save_path,
                "3D camera trajectory",
            )

        options["num_frames"] = T

        sampler, num_rows, num_cols = init_sampling(options=options)
        num_samples = num_rows * num_cols

        decoding_t = st.number_input(
            "Decode t frames at a time (set small if you are low on VRAM)",
            value=options.get("decoding_t", T),
            min_value=1,
            max_value=int(1e9),
        )

        if st.checkbox("Overwrite fps in mp4 generator", False):
            saving_fps = st.number_input(
                "saving video at fps:", value=value_dict["fps"], min_value=1
            )
        else:
            saving_fps = value_dict["fps"]

        if st.button("Sample"):
            out = do_sample(
                model,
                sampler,
                value_dict,
                num_samples,
                H,
                W,
                C,
                F,
                T=T,
                batch2model_input=["num_video_frames", "image_only_indicator"],
                force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None),
                force_cond_zero_embeddings=options.get(
                    "force_cond_zero_embeddings", None
                ),
                return_latents=False,
                decoding_t=decoding_t,
            )

            if isinstance(out, (tuple, list)):
                samples, samples_z = out
            else:
                samples = out
                samples_z = None

            if save_locally:
                save_video_as_grid_and_mp4(samples, save_path, T, fps=saving_fps)