# PBRGeneration / app.py
import os
import cv2
import torch
import numpy as np
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download
from depth_anything_v2.dpt import DepthAnythingV2
# Depth Anything V2 encoder configurations
model_configs = {
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
}
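# The two classes below follow the ComfyUI custom-node convention
# (INPUT_TYPES / RETURN_TYPES / FUNCTION / CATEGORY class attributes).
# That metadata is unused here; the classes are instantiated directly
# and only their processing methods are called.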
class NormalMapSimple:
@classmethod
    def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"scale_XY": ("FLOAT",{"default": 1, "min": 0, "max": 100, "step": 0.001}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "normal_map"
CATEGORY = "image/filters"
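    # Height-to-normal: Scharr X/Y derivatives of the luminance become the
    # normal's X/Y components, Z is fixed at 1, and the result is normalized
    # and remapped from [-1, 1] to the [0, 1] image range.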
def normal_map(self, images, scale_XY):
t = images.detach().clone().cpu().numpy().astype(np.float32)
L = np.mean(t[:,:,:,:3], axis=3)
for i in range(t.shape[0]):
            t[i,:,:,0] = cv2.Scharr(L[i], -1, 1, 0, borderType=cv2.BORDER_REFLECT) * -1  # dI/dx, inverted
            t[i,:,:,1] = cv2.Scharr(L[i], -1, 0, 1, borderType=cv2.BORDER_REFLECT)       # dI/dy
t[:,:,:,2] = 1
t = torch.from_numpy(t)
t[:,:,:,:2] *= scale_XY
t[:,:,:,:3] = torch.nn.functional.normalize(t[:,:,:,:3], dim=3) / 2 + 0.5
return (t,)
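# Converts between normal-map conventions as implemented here: "Standard" is
# OpenGL-style (G = +Y), "DirectX" flips the green channel, "BAE" stores an
# inverted red channel, and "MiDaS" uses BGR channel order with an inverted X.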
class ConvertNormals:
@classmethod
    def INPUT_TYPES(cls):
return {
"required": {
"normals": ("IMAGE",),
"input_mode": (["BAE", "MiDaS", "Standard", "DirectX"],),
"output_mode": (["BAE", "MiDaS", "Standard", "DirectX"],),
"scale_XY": ("FLOAT",{"default": 1, "min": 0, "max": 100, "step": 0.001}),
"normalize": ("BOOLEAN", {"default": True}),
"fix_black": ("BOOLEAN", {"default": True}),
},
"optional": {
"optional_fill": ("IMAGE",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "convert_normals"
CATEGORY = "image/filters"
def convert_normals(self, normals, input_mode, output_mode, scale_XY, normalize, fix_black, optional_fill=None):
try:
t = normals.detach().clone()
if input_mode == "BAE":
t[:,:,:,0] = 1 - t[:,:,:,0] # invert R
elif input_mode == "MiDaS":
t[:,:,:,:3] = torch.stack([1 - t[:,:,:,2], t[:,:,:,1], t[:,:,:,0]], dim=3) # BGR -> RGB and invert R
elif input_mode == "DirectX":
t[:,:,:,1] = 1 - t[:,:,:,1] # invert G
if fix_black:
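                # 'key' masks degenerate pixels whose Z component falls below
                # 0.5 (e.g. black borders); they are filled with a flat
                # up-facing normal, or with the optional fill image.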
key = torch.clamp(1 - t[:,:,:,2] * 2, min=0, max=1)
if optional_fill is None:
t[:,:,:,0] += key * 0.5
t[:,:,:,1] += key * 0.5
t[:,:,:,2] += key
else:
fill = optional_fill.detach().clone()
if fill.shape[1:3] != t.shape[1:3]:
fill = torch.nn.functional.interpolate(fill.movedim(-1,1), size=(t.shape[1], t.shape[2]), mode='bilinear').movedim(1,-1)
if fill.shape[0] != t.shape[0]:
fill = fill[0].unsqueeze(0).expand(t.shape[0], -1, -1, -1)
t[:,:,:,:3] += fill[:,:,:,:3] * key.unsqueeze(3).expand(-1, -1, -1, 3)
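            # Rescale X/Y about the 0.5 midpoint, exaggerating or flattening
            # the encoded slope before renormalization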
t[:,:,:,:2] = (t[:,:,:,:2] - 0.5) * scale_XY + 0.5
if normalize:
# Transform to [-1, 1] range
t_norm = t[:,:,:,:3] * 2 - 1
# Calculate the length of each vector
lengths = torch.sqrt(torch.sum(t_norm**2, dim=3, keepdim=True))
# Avoid division by zero
lengths = torch.clamp(lengths, min=1e-6)
# Normalize each vector to unit length
t_norm = t_norm / lengths
# Transform back to [0, 1] range
t[:,:,:,:3] = (t_norm + 1) / 2
if output_mode == "BAE":
t[:,:,:,0] = 1 - t[:,:,:,0] # invert R
elif output_mode == "MiDaS":
t[:,:,:,:3] = torch.stack([t[:,:,:,2], t[:,:,:,1], 1 - t[:,:,:,0]], dim=3) # invert R and BGR -> RGB
elif output_mode == "DirectX":
t[:,:,:,1] = 1 - t[:,:,:,1] # invert G
return (t,)
except Exception as e:
print(f"Error in convert_normals: {str(e)}")
return (normals,)
def get_image_intensity(img, gamma_correction=1.0):
"""
Extract intensity map from an image using HSV color space
"""
# Convert to HSV color space
result = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
# Extract Value channel (intensity)
result = result[:, :, 2].astype(np.float32) / 255.0
# Apply gamma correction
result = result ** gamma_correction
# Convert back to 0-255 range
result = (result * 255.0).clip(0, 255).astype(np.uint8)
# Convert to RGB (still grayscale but in RGB format)
result = cv2.cvtColor(result, cv2.COLOR_GRAY2RGB)
return result
def blend_numpy_images(image1, image2, blend_factor=0.25, mode="normal"):
    """
    Blend two uint8 numpy images. Only "normal" (linear interpolation)
    blending is implemented; the mode argument is accepted but unused.
    """
# Convert to float32 and normalize to 0-1
img1 = image1.astype(np.float32) / 255.0
img2 = image2.astype(np.float32) / 255.0
# Normal blend mode
blended = img1 * (1 - blend_factor) + img2 * blend_factor
# Convert back to uint8
blended = (blended * 255.0).clip(0, 255).astype(np.uint8)
return blended
def process_normal_map(image):
"""
Process image through NormalMapSimple and ConvertNormals
"""
# Convert numpy image to torch tensor with batch dimension
image_tensor = torch.from_numpy(image).unsqueeze(0).float() / 255.0
# Create instances of the classes
normal_map_generator = NormalMapSimple()
normal_converter = ConvertNormals()
# Generate initial normal map
normal_map = normal_map_generator.normal_map(image_tensor, scale_XY=1.0)[0]
# Convert normal map from Standard to DirectX
converted_normal = normal_converter.convert_normals(
normal_map,
input_mode="Standard",
output_mode="DirectX",
scale_XY=1.0,
normalize=True,
fix_black=True
)[0]
# Convert back to numpy array
result = (converted_normal.squeeze(0).numpy() * 255).astype(np.uint8)
return result
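# Hypothetical usage of the pipeline above (file name is illustrative only):
#   img = cv2.cvtColor(cv2.imread("height.png"), cv2.COLOR_BGR2RGB)
#   normal = process_normal_map(img)  # HxWx3 uint8 DirectX normal map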
# Download and initialize model
def initialize_model():
encoder = 'vitl'
max_depth = 1
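    # max_depth configures the metric-depth head; 1 presumably matches how
    # the custom checkpoint was fine-tuned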
model = DepthAnythingV2(**{**model_configs[encoder], 'max_depth': max_depth})
    # Download model weights from the private repo (token read from the 'Read' secret)
    model_path = hf_hub_download(
        "NightRaven109/DepthAnythingv2custom",
        "Modelnew100.pth",
        token=os.environ['Read']
    )
# Load checkpoint
checkpoint = torch.load(model_path, map_location='cpu')
    # The checkpoint bundles training metadata with the weights; the single
    # remaining key (typically 'model') holds the actual state dict
    state_dict = {}
    for key in checkpoint.keys():
        if key not in ['optimizer', 'epoch', 'previous_best']:
            state_dict = checkpoint[key]
    # Strip the 'module.' prefix that DataParallel training leaves on keys
my_state_dict = {}
for key in state_dict.keys():
new_key = key.replace('module.', '')
my_state_dict[new_key] = state_dict[key]
model.load_state_dict(my_state_dict)
return model
# Initialize model at startup
MODEL = initialize_model()
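# @spaces.GPU requests ZeroGPU hardware for the duration of the decorated
# call; the model is moved to CUDA on entry and returned to the CPU after.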
@spaces.GPU
def process_image(input_image):
"""
Process the input image and return depth map and normal map
"""
if input_image is None:
return None, None
# Move model to GPU for processing
MODEL.to('cuda')
MODEL.eval()
# Convert from RGB to BGR for depth processing
input_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
with torch.no_grad():
# Get depth map
depth = MODEL.infer_image(input_bgr)
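        # infer_image handles resizing and normalization internally and
        # returns an HxW float32 depth array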
        # Apply Gaussian blur to smooth the depth map
kernel_size = (15, 15) # Size of the Gaussian kernel (must be odd and positive)
sigma = 0 # If 0, sigma is calculated based on kernel size
depth = cv2.GaussianBlur(depth, kernel_size, sigma)
print(f"Applied Gaussian Blur with kernel size {kernel_size} and sigma {sigma}")
        # Normalize depth for visualization (0-255); guard against a constant map
        depth_range = max(float(depth.max() - depth.min()), 1e-6)
        depth_normalized = ((depth - depth.min()) / depth_range * 255).astype(np.uint8)
# Move model back to CPU
MODEL.to('cpu')
# Get intensity map
intensity_map = get_image_intensity(np.array(input_image), gamma_correction=1.0)
    # Blend the depth map (75%) with the intensity map (25%) so fine surface
    # detail survives into the normal-map stage
blended_result = blend_numpy_images(
cv2.cvtColor(depth_normalized, cv2.COLOR_GRAY2RGB), # Convert depth to RGB
intensity_map,
blend_factor=0.25,
mode="normal"
)
# Generate normal map from blended result
normal_map = process_normal_map(blended_result)
return depth_normalized, normal_map
@spaces.GPU
def gradio_interface(input_img):
try:
depth_raw, normal = process_image(input_img)
return [depth_raw, normal]
except Exception as e:
print(f"Error processing image: {str(e)}")
return [None, None]
# Define interface
iface = gr.Interface(
fn=gradio_interface,
inputs=gr.Image(label="Input Image"),
outputs=[
gr.Image(label="Raw Depth Map"),
gr.Image(label="DirectX Normal Map")
],
title="Depth and Normal Map Generation",
description="Upload an image to generate its depth map and normal map.",
examples=[
"0269B55506557D8D_diffuse.png",
"Brick_Painted_sb0hkjp0_4K_surface_msAlbedo_baked.jpg",
"Concrete_rlvlbep0_4K_surface_msAlbedo_baked.jpg",
"Grass_Dried_scmkvwp0_4K_surface_msAlbedo_baked.jpg",
"Stone_Tile_uc2jdbpg_8K_surface_msAlbedo_baked.jpg",
"PavingStones144_1K-PNG_Color.png",
"Surface_Tiles_smgmjog_8K_surface_msAlbedo_baked.jpg",
"Panel.jpg"
]
)
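# The example images listed above are expected to sit alongside app.py in
# the Space repository.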
# Launch the app
if __name__ == "__main__":
iface.launch()