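# NOTE: the text below is file-viewer page residue captured together with this
# module; it is not program code and is kept only as a comment:
#   encounter1997 / rm pdb / 38f80e4 / raw / history blame / 4.93 kB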
from __future__ import annotations
import datetime
import os
import pathlib
import shlex
import shutil
import subprocess
import sys
import gradio as gr
import slugify
import torch
import huggingface_hub
from huggingface_hub import HfApi
from omegaconf import OmegaConf
# Default Space id, used when the SPACE_ID environment variable is not set.
ORIGINAL_SPACE_ID = 'BAAI/vid2vid-zero'
# Space id read from the environment, falling back to the default above.
SPACE_ID = os.environ.get('SPACE_ID', ORIGINAL_SPACE_ID)
class Runner:
    """Fetch base-model checkpoints and launch vid2vid-zero inference.

    Attributes:
        hf_token: Optional Hugging Face token. When set it is preferred over
            the per-request token for the Space hardware downgrade.
        checkpoint_dir: ``checkpoints`` directory (relative to the current
            working directory) that git-cloned models are stored under.
    """

    def __init__(self, hf_token: str | None = None):
        """Store the token and make sure the checkpoint directory exists."""
        self.hf_token = hf_token
        self.checkpoint_dir = pathlib.Path('checkpoints')
        self.checkpoint_dir.mkdir(exist_ok=True)

    def download_base_model(self, base_model_id: str, token=None) -> str:
        """Make the base model available locally and return its directory.

        Args:
            base_model_id: Hub repository id, e.g. ``'org/model'``.
            token: Optional Hugging Face token. When it is not ``None`` the
                repository is fetched with
                ``huggingface_hub.snapshot_download``; otherwise it is cloned
                with ``git lfs clone`` under ``checkpoint_dir``.

        Returns:
            POSIX path of the cloned checkpoint directory, or the path that
            ``snapshot_download`` returned when a token was supplied.
        """
        model_dir = self.checkpoint_dir / base_model_id
        org_dir = self.checkpoint_dir / base_model_id.split('/')[0]
        model_url = f'https://huggingface.co/{base_model_id}'
        print(model_url)

        if token is not None:
            # NOTE(review): newer huggingface_hub releases prefer `token=`;
            # `use_auth_token` is kept to match the version this file targets.
            snapshot_path = huggingface_hub.snapshot_download(
                base_model_id, use_auth_token=token)
            print(snapshot_path, org_dir)
            return snapshot_path

        # Only clone when the checkpoint is not already on disk; cloning into
        # an existing directory just fails.
        if not model_dir.exists():
            org_dir.mkdir(parents=True, exist_ok=True)
            # argv lists keep the model id a single argument whatever
            # characters it contains.
            subprocess.run(['git', 'lfs', 'install'], cwd=org_dir)
            subprocess.run(['git', 'lfs', 'clone', model_url], cwd=org_dir)
        return model_dir.as_posix()

    def join_model_library_org(self, token: str) -> None:
        """POST to the join-organization URL with ``curl`` using *token*.

        NOTE(review): ``URL_TO_JOIN_MODEL_LIBRARY_ORG`` is not defined in the
        visible part of this module; calling this method raises ``NameError``
        unless the name is provided elsewhere.
        """
        subprocess.run([
            'curl',
            '-X', 'POST',
            '-H', f'Authorization: Bearer {token}',
            '-H', 'Content-Type: application/json',
            URL_TO_JOIN_MODEL_LIBRARY_ORG,
        ])

    def run_vid2vid_zero(
        self,
        model_path: str,
        input_video: str,
        prompt: str,
        n_sample_frames: int,
        sample_start_idx: int,
        sample_frame_rate: int,
        validation_prompt: str,
        guidance_scale: float,
        resolution: str,
        seed: int,
        remove_gpu_after_running: bool,
        input_token: str | None = None,
    ) -> str:
        """Write a run config and launch ``test_vid2vid_zero.py`` with it.

        Args:
            model_path: Hub id of the base model, passed to
                ``download_base_model``.
            input_video: Uploaded video; either a path string or an object
                whose ``name`` attribute is the path.
            prompt: Prompt stored as ``input_data.prompt``; also used to name
                the output directory.
            n_sample_frames: Stored as ``input_data.n_sample_frames``.
            sample_start_idx: Stored as ``input_data.sample_start_idx``.
            sample_frame_rate: Stored as ``input_data.sample_frame_rate``.
            validation_prompt: Single entry of ``validation_data.prompts``.
            guidance_scale: Stored as ``validation_data.guidance_scale``.
            resolution: Used for both width and height of the input and
                validation data (converted with ``int``).
            seed: Stored as ``seed`` in the config.
            remove_gpu_after_running: When true and the ``SPACE_ID``
                environment variable is set, request ``cpu-basic`` hardware
                for that Space after the run.
            input_token: Optional Hugging Face token for the model download;
                also used for the hardware request when ``hf_token`` is unset.

        Returns:
            Path of ``sample-all.mp4`` inside the run's output directory.

        Raises:
            gr.Error: If CUDA is unavailable, a required input is missing, or
                the inference subprocess exits with a non-zero status.
        """
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')
        if input_video is None:
            raise gr.Error('You need to upload a video.')
        if not prompt:
            raise gr.Error('The input prompt is missing.')
        if not validation_prompt:
            raise gr.Error('The validation prompt is missing.')

        resolution = int(resolution)
        n_sample_frames = int(n_sample_frames)
        sample_start_idx = int(sample_start_idx)
        sample_frame_rate = int(sample_frame_rate)

        # The prompt comes from the web UI: strip path separators so it names
        # exactly one directory under outputs/.
        prompt_path = prompt.replace(' ', '_')
        for separator in filter(None, (os.sep, os.altsep)):
            prompt_path = prompt_path.replace(separator, '_')

        repo_dir = pathlib.Path(__file__).parent
        output_dir = repo_dir / 'outputs' / prompt_path
        output_dir.mkdir(parents=True, exist_ok=True)

        config = OmegaConf.load('configs/black-swan.yaml')
        config.pretrained_model_path = self.download_base_model(
            model_path, token=input_token)
        # we remove null-inversion & use fp16 for fast inference on web demo
        config.mixed_precision = "fp16"
        config.validation_data.use_null_inv = False

        config.output_dir = output_dir.as_posix()
        # Accept both an uploaded-file object (path in `.name`) and a plain
        # path string.
        config.input_data.video_path = getattr(input_video, 'name',
                                               input_video)
        config.input_data.prompt = prompt
        config.input_data.n_sample_frames = n_sample_frames
        config.input_data.width = resolution
        config.input_data.height = resolution
        config.input_data.sample_start_idx = sample_start_idx
        config.input_data.sample_frame_rate = sample_frame_rate

        config.validation_data.prompts = [validation_prompt]
        config.validation_data.video_length = 8
        config.validation_data.width = resolution
        config.validation_data.height = resolution
        config.validation_data.num_inference_steps = 50
        config.validation_data.guidance_scale = guidance_scale

        config.input_batch_size = 1
        config.seed = seed

        config_path = output_dir / 'config.yaml'
        with open(config_path, 'w', encoding='utf-8') as f:
            OmegaConf.save(config, f)

        # An argv list (not a shell-style string) keeps the config path one
        # argument even when the prompt contains quotes or the repository
        # path contains spaces.
        completed = subprocess.run([
            'accelerate', 'launch', 'test_vid2vid_zero.py',
            '--config', str(config_path),
        ])
        succeeded = completed.returncode == 0

        output_video_path = os.path.join(output_dir, "sample-all.mp4")
        print(f"video path for gradio: {output_video_path}")
        if succeeded:
            print('Running completed!')

        # Downgrade the hardware before reporting a failure so the GPU is
        # released either way, as it was when the exit status was ignored.
        if remove_gpu_after_running:
            space_id = os.getenv('SPACE_ID')
            if space_id:
                api = HfApi(token=self.hf_token or input_token)
                api.request_space_hardware(repo_id=space_id,
                                           hardware='cpu-basic')

        if not succeeded:
            raise gr.Error(
                f'vid2vid-zero exited with status {completed.returncode}.')
        return output_video_path