# zero123-face / app.py
import math

import numpy as np
import torch
import gradio as gr
from PIL import Image
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection

from pipeline_zero1to3_stable import CCProjection, Zero1to3StableDiffusionPipeline
# Resize and normalize the input image. (Despite the original comment, no
# background removal actually happens here; the masking line below is disabled.)
def preprocess_image(input_im):
    '''
    :param input_im: PIL Image.
    :return: (H, W, 3) float32 array in [0, 1].
    '''
    input_im = input_im.convert('RGB')
    input_im = input_im.resize([256, 256], Image.Resampling.LANCZOS)
    input_im = np.asarray(input_im, dtype=np.float32) / 255.0
    # input_im[input_im[:, :, -1] <= 0.9] = [1., 1., 1.]  # alpha-mask fill, left disabled
    return input_im
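
# A sketch of the alpha-based white-background fill that the disabled mask line
# above gestures at; it would have to run on RGBA input *before* the RGB
# conversion. Hypothetical helper, not called anywhere in this Space.
def composite_on_white(img):
    '''Flatten an RGBA image onto a white background; pass other modes through.'''
    if img.mode == 'RGBA':
        background = Image.new('RGBA', img.size, (255, 255, 255, 255))
        img = Image.alpha_composite(background, img)
    return img.convert('RGB')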

# Load the model components from the Hub. `use_safetensors=True` applies only
# to weight-bearing modules; the feature extractor and scheduler are config-only.
model_id = "mirza152/zero123-face"
cc_projection = CCProjection.from_pretrained(model_id, subfolder="cc_projection", use_safetensors=True)
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", use_safetensors=True)
feature_extractor = CLIPFeatureExtractor.from_pretrained(model_id, subfolder="feature_extractor")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(model_id, subfolder="image_encoder")
scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", use_safetensors=True)
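# On a CUDA machine the weight-bearing modules could instead be loaded in half
# precision to roughly halve memory, e.g. (illustrative, not done in this Space):
#   unet = UNet2DConditionModel.from_pretrained(
#       model_id, subfolder="unet", use_safetensors=True, torch_dtype=torch.float16
#   )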

# Instantiate the pipeline (safety checker disabled, matching the checkpoint).
pipe = Zero1to3StableDiffusionPipeline(
    unet=unet,
    cc_projection=cc_projection,
    vae=vae,
    scheduler=scheduler,
    feature_extractor=feature_extractor,
    image_encoder=image_encoder,
    safety_checker=None,
)
# Memory savers: tile the VAE and slice attention so frames fit in limited RAM.
pipe.enable_vae_tiling()
pipe.enable_attention_slicing()
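# Move the pipeline to GPU when one is available; the original script left
# everything on the default device, so this line is a small, optional addition.
pipe.to("cuda" if torch.cuda.is_available() else "cpu")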

# Generate a short orbit of novel views from one face image and write a GIF.
def process_image(input_image):
    input_image = preprocess_image(input_image)
    H, W = input_image.shape[:2]
    input_image = Image.fromarray((input_image * 255.0).astype(np.uint8))

    total_frames = 8
    input_images = [input_image] * total_frames
    pitch_range, yaw_range = 0.20, 0.20
    avg_polar, avg_azimuth = 1.52, 1.57  # mean camera angles, in radians

    # One relative camera pose per frame: oscillate around the mean polar and
    # azimuth angles over a full period so the GIF loops smoothly.
    all_poses = []
    for frame_idx in range(total_frames):
        theta_target = math.pi / 2 + pitch_range * math.sin(2 * math.pi * frame_idx / total_frames)
        polar = avg_polar - theta_target
        azimuth_cond = math.pi / 2 - 0.05 + yaw_range * math.cos(2 * math.pi * frame_idx / total_frames)
        azimuth = avg_azimuth - azimuth_cond
        # Zero-1-to-3 pose conditioning: [delta polar, sin(delta azimuth), cos(delta azimuth), radius term].
        query_pose = torch.tensor([polar, math.sin(azimuth), math.cos(azimuth), math.pi / 2 - avg_azimuth])
        all_poses.append(query_pose)
    query_poses = torch.stack(all_poses)

    # NOTE: a single denoising step keeps the demo fast but produces rough
    # frames; Zero-1-to-3 is usually run with ~50-75 steps for real quality.
    images = pipe(
        input_imgs=input_images,
        prompt_imgs=input_images,
        poses=query_poses,
        height=H,
        width=W,
        guidance_scale=4,
        num_images_per_prompt=1,
        num_inference_steps=1,
    ).images

    # Assemble the frames into a looping GIF: 100 ms per frame, infinite loop.
    gif_path = "output.gif"
    images[0].save(gif_path, save_all=True, append_images=images[1:], duration=100, loop=0)
    return gif_path
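
# For reference, the same GIF export via imageio (a sketch, not used above; it
# would need `import imageio` and the `images`/`gif_path` names from the function):
#   imageio.mimsave(gif_path, [np.asarray(im) for im in images], duration=0.1, loop=0)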

# Create the Gradio interface.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Image(type="filepath", label="Output GIF"),
    title="Image to GIF Pipeline",
    description="Upload a face image to generate a short novel-view GIF.",
    allow_flagging="never",
)

iface.launch()