# Source listing metadata: file size 4,684 bytes, revision 07c6a04.
import os
from functools import partial
from typing import Any, Optional
import imageio
import torch
import videosys
from .mp_utils import ProcessWorkerWrapper, ResultHandler, WorkerMonitor, get_distributed_init_method, get_open_port
class VideoSysEngine:
    """Run a VideoSys pipeline on a driver plus optional parallel workers.

    The constructing process builds the rank-0 ("driver") pipeline itself.
    For ``world_size > 1`` every other rank gets its pipeline created through
    a ``ProcessWorkerWrapper`` (presumably one separate process per rank --
    see ``mp_utils``). This is partly inspired by vllm.
    """

    def __init__(self, config):
        """Store *config* and create the driver and worker pipelines.

        Args:
            config: Engine configuration. This class reads
                ``config.world_size`` and ``config.pipeline_cls`` and passes
                the whole object to ``pipeline_cls``.
        """
        self.config = config
        # Reserved for futures of asynchronously started worker tasks; only
        # consumed by stop_remote_worker_execution_loop().
        self.parallel_worker_tasks = None
        self._init_worker(config.pipeline_cls)

    def _init_worker(self, pipeline_cls):
        """Prepare the environment and build one pipeline per rank.

        Sets ``self.workers``, ``self.worker_monitor`` and
        ``self.driver_worker``.

        Args:
            pipeline_cls: Callable invoked as ``pipeline_cls(config)`` to
                build the pipeline of each rank.

        Raises:
            AssertionError: If ``world_size`` exceeds the number of CUDA
                devices reported by torch.
        """
        world_size = self.config.world_size

        # Expose devices 0..world_size-1 unless the caller already chose.
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(world_size))

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Set OMP_NUM_THREADS to 1 if it is not set explicitly, avoids CPU
        # contention amongst the shards
        os.environ.setdefault("OMP_NUM_THREADS", "1")

        # NOTE: The two following lines need adaption for multi-node
        available_devices = torch.cuda.device_count()
        assert world_size <= available_devices, (
            f"world_size ({world_size}) exceeds the number of visible CUDA devices ({available_devices})"
        )

        # change addr for multi-node
        distributed_init_method = get_distributed_init_method("127.0.0.1", get_open_port())

        if world_size == 1:
            # Single rank: only the driver pipeline below is needed.
            self.workers = []
            self.worker_monitor = None
        else:
            result_handler = ResultHandler()
            # Rank 0 is the driver, so wrapped workers start at rank 1.
            self.workers = [
                ProcessWorkerWrapper(
                    result_handler,
                    partial(
                        self._create_pipeline,
                        pipeline_cls=pipeline_cls,
                        rank=rank,
                        local_rank=rank,
                        distributed_init_method=distributed_init_method,
                    ),
                )
                for rank in range(1, world_size)
            ]

            self.worker_monitor = WorkerMonitor(self.workers, result_handler)
            result_handler.start()
            self.worker_monitor.start()

        # The driver uses the default rank=0 / local_rank=0.
        self.driver_worker = self._create_pipeline(
            pipeline_cls=pipeline_cls, distributed_init_method=distributed_init_method
        )

    # TODO: add more options here for pipeline, or wrap all options into config
    def _create_pipeline(self, pipeline_cls, rank=0, local_rank=0, distributed_init_method=None):
        """Initialize videosys for *rank* and return a new pipeline.

        Args:
            pipeline_cls: Callable invoked as ``pipeline_cls(self.config)``.
            rank: Rank passed to ``videosys.initialize``.
            local_rank: Accepted for interface compatibility; currently
                unused by this method.
            distributed_init_method: Forwarded to ``videosys.initialize`` as
                ``init_method``.

        Returns:
            The object returned by ``pipeline_cls(self.config)``.
        """
        videosys.initialize(rank=rank, world_size=self.config.world_size, init_method=distributed_init_method, seed=42)
        pipeline = pipeline_cls(self.config)
        return pipeline

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Run the given method on all workers.

        Args:
            method: Name of the method to call on each worker and on the
                driver pipeline.
            *args: Positional arguments forwarded to the method.
            async_run_tensor_parallel_workers_only: If true, only dispatch to
                the non-driver workers and return their pending results
                without running the driver or waiting.
            max_concurrent_workers: Accepted for interface compatibility;
                currently unused.
            **kwargs: Keyword arguments forwarded to the method.

        Returns:
            The list of pending worker results when
            ``async_run_tensor_parallel_workers_only`` is true; otherwise a
            list with the driver's output first, followed by each worker's
            resolved output in ``self.workers`` order.
        """
        # Start the workers first.
        worker_outputs = [worker.execute_method(method, *args, **kwargs) for worker in self.workers]

        if async_run_tensor_parallel_workers_only:
            # Just return futures
            return worker_outputs

        driver_worker_method = getattr(self.driver_worker, method)
        driver_worker_output = driver_worker_method(*args, **kwargs)

        # Get the results of the workers.
        return [driver_worker_output] + [output.get() for output in worker_outputs]

    def _driver_execute_model(self, *args, **kwargs):
        """Call ``generate`` on the driver pipeline only and return its result."""
        return self.driver_worker.generate(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Call ``generate`` on every rank and return the driver's output."""
        return self._run_workers("generate", *args, **kwargs)[0]

    def stop_remote_worker_execution_loop(self) -> None:
        """Wait for any recorded parallel worker tasks, then forget them.

        Does nothing when ``self.parallel_worker_tasks`` is ``None``.
        """
        if self.parallel_worker_tasks is None:
            return

        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only to complete."""
        for result in parallel_worker_tasks:
            result.get()

    def save_video(self, video, output_path, fps=24):
        """Write *video* to *output_path* via ``imageio.mimwrite``.

        Args:
            video: Frame data handed unchanged to ``imageio.mimwrite``.
            output_path: Destination file. Missing parent directories are
                created; a bare filename is written to the current directory.
            fps: Frames per second passed to ``imageio.mimwrite``.
        """
        parent_dir = os.path.dirname(output_path)
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        imageio.mimwrite(output_path, video, fps=fps)

    def shutdown(self):
        """Close the worker monitor and destroy the torch process group.

        Safe to call more than once, and safe on a partially constructed
        engine: each resource is released only if it currently exists.
        """
        if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
            # Drop the reference first so a repeated call (e.g. from __del__)
            # does not close the same monitor again.
            self.worker_monitor = None
            worker_monitor.close()

        # destroy_process_group() raises when no default group exists, which
        # is the case on a second shutdown or if initialization never ran.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

    def __del__(self):
        """Release worker and process-group resources on garbage collection."""
        self.shutdown()