Spaces:
Runtime error
Runtime error
File size: 1,340 Bytes
ec0c335 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import gc
from threading import Thread
import torch
from diffusers import DDIMScheduler
from fastchat.utils import build_logger
# Module-level logger shared by the code in this file. The two arguments are
# handed straight to FastChat's build_logger; presumably they are the logger
# name and the log file name -- confirm against fastchat.utils.
logger = build_logger("diffusion_infer", 'diffusion_infer.log')
@torch.inference_mode()
def generate_stream_lavie(
    model,
    tokenizer,
    params,
    device,
    context_len=256,
    stream_interval=2,
):
    """Run the video pipeline once for ``params["prompt"]`` and yield the result.

    The generator yields exactly one dict and then finishes.

    Args:
        model: Callable pipeline. It is invoked with ``prompt``,
            ``video_length``, ``height``, ``width``, ``num_inference_steps``
            and ``guidance_scale`` keyword arguments, and its return value
            must expose a ``video`` sequence.
        tokenizer: Callable used only to count prompt tokens. It is called
            with ``return_tensors="pt"`` and its result must support
            ``.to(device)`` and expose ``input_ids``.
        params: Mapping that must contain the key ``"prompt"``.
        device: Device the encoded prompt is moved to; also compared against
            ``"xpu"`` and ``"npu"`` to pick the extra cache to empty.
        context_len: Not used by this function; kept so the signature stays
            unchanged for existing callers.
        stream_interval: Not used by this function; kept so the signature
            stays unchanged for existing callers.

    Yields:
        dict: ``"text"`` holds the first entry of the pipeline's ``video``
        output, ``"usage"`` holds the prompt token count (completion tokens
        are always reported as 0), and ``"finish_reason"`` is ``"stop"``.

    Raises:
        KeyError: If ``params`` has no ``"prompt"`` entry.
    """
    prompt = params["prompt"]
    encoding = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = encoding.input_ids
    # The ids were requested as a tensor; for a single prompt that tensor is
    # expected to be batched as (1, seq_len), so len() would report the batch
    # size instead of the number of tokens. The last dimension is the token
    # count for both the batched and the un-batched layout.
    input_echo_len = input_ids.shape[-1]
    logger.info(f"prompt: {prompt}")

    try:
        output = model(
            prompt=prompt,
            video_length=16,
            height=360,
            width=512,
            num_inference_steps=50,
            guidance_scale=7.5,
        ).video[0]

        yield {
            "text": output,
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": 0,
                "total_tokens": input_echo_len,
            },
            "finish_reason": "stop",
        }
    finally:
        # Runs even when the consumer closes the generator right after the
        # first item or when the pipeline call raises; code placed after the
        # yield without a finally would be skipped in both cases.
        gc.collect()
        torch.cuda.empty_cache()
        if device == "xpu":
            torch.xpu.empty_cache()
        if device == "npu":
            torch.npu.empty_cache()
|