import gc

import torch
from fastchat.utils import build_logger

logger = build_logger("diffusion_infer", "diffusion_infer.log")
def generate_stream_lavie(
    model,
    tokenizer,
    params,
    device,
    context_len=256,
    stream_interval=2,
):
    prompt = params["prompt"]
    encoding = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = encoding.input_ids
    # input_ids has shape (batch, seq_len); count tokens, not batch rows.
    input_echo_len = input_ids.shape[1]
    logger.info(f"prompt: {prompt}")
    output = model(
        prompt=prompt,
        video_length=16,
        height=360,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
    ).video[0]
    # Diffusion runs in one shot, so emit a single final chunk in the
    # FastChat streaming format.
    yield {
        "text": output,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": 0,
            "total_tokens": input_echo_len,
        },
        "finish_reason": "stop",
    }
    # Free accelerator memory once generation is done. torch.cuda.empty_cache()
    # is a no-op when CUDA is not initialized, so it is safe on other devices.
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
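

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how a caller might drain the generator.
# `model` and `tokenizer` are assumed to be a loaded LaVie-style text-to-video
# pipeline and its text tokenizer; the loading code lives outside this module,
# and the prompt below is made up.
#
#   params = {"prompt": "a panda strumming a guitar on a mountain top"}
#   for chunk in generate_stream_lavie(model, tokenizer, params, device="cuda"):
#       frames = chunk["text"]  # the generated video frames
#       print(chunk["usage"], chunk["finish_reason"])
# ---------------------------------------------------------------------------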