import torch
import torch.nn as nn
import numpy as np
import os
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    PerspectiveCameras,
    FoVPerspectiveCameras,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    TexturesUV,
    TexturesVertex,
    blending,
)
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.renderer.blending import (
    BlendParams,
    hard_rgb_blend,
    sigmoid_alpha_blend,
    softmax_rgb_blend,
)


class SoftSimpleShader(nn.Module):
    """
    Per pixel lighting - the lighting model is applied using the interpolated
    coordinates and normals for each pixel. The blending function returns the
    soft aggregated color using all the faces per pixel.

    To use the default values, simply initialize the shader with the desired
    device e.g.

        shader = SoftSimpleShader(device=torch.device("cuda:0"))
    """

    def __init__(
        self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
    ):
        super().__init__()
        self.lights = lights if lights is not None else PointLights(device=device)
        self.materials = (
            materials if materials is not None else Materials(device=device)
        )
        self.cameras = cameras
        self.blend_params = blend_params if blend_params is not None else BlendParams()

    def to(self, device):
        # Manually move to device modules which are not subclasses of nn.Module
        self.cameras = self.cameras.to(device)
        self.materials = self.materials.to(device)
        self.lights = self.lights.to(device)
        return self

    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
        texels = meshes.sample_textures(fragments)
        blend_params = kwargs.get("blend_params", self.blend_params)
        cameras = kwargs.get("cameras", self.cameras)
        if cameras is None:
            msg = (
                "Cameras must be specified either at initialization "
                "or in the forward pass of SoftSimpleShader"
            )
            raise ValueError(msg)
        znear = kwargs.get("znear", getattr(cameras, "znear", 1.0))
        zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))
        # Unlike SoftPhongShader, no lighting model is applied here: the
        # sampled texels (already shaded per vertex) are blended directly
        # across the faces covering each pixel.
        images = softmax_rgb_blend(
            texels, fragments, blend_params, znear=znear, zfar=zfar
        )
        return images


class Render_3DMM(nn.Module):
    def __init__(
        self,
        focal=1015,
        img_h=500,
        img_w=500,
        batch_size=1,
        device=torch.device("cuda:0"),
    ):
        super(Render_3DMM, self).__init__()

        self.focal = focal
        self.img_h = img_h
        self.img_w = img_w
        self.device = device
        self.renderer = self.get_render(batch_size)

        dir_path = os.path.dirname(os.path.realpath(__file__))
        topo_info = np.load(
            os.path.join(dir_path, "3DMM", "topology_info.npy"), allow_pickle=True
        ).item()
        self.tris = torch.as_tensor(topo_info["tris"]).to(self.device)
        self.vert_tris = torch.as_tensor(topo_info["vert_tris"]).to(self.device)

    def compute_normal(self, geometry):
        # geometry: (batch, num_vertex, 3)
        vert_1 = torch.index_select(geometry, 1, self.tris[:, 0])
        vert_2 = torch.index_select(geometry, 1, self.tris[:, 1])
        vert_3 = torch.index_select(geometry, 1, self.tris[:, 2])
        # Face normals from the cross product of two triangle edges.
        nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, dim=2)
        tri_normal = nn.functional.normalize(nnorm, dim=2)
        # Vertex normals: sum the normals of the incident faces, renormalize.
        v_norm = tri_normal[:, self.vert_tris, :].sum(2)
        vert_normal = v_norm / v_norm.norm(dim=2).unsqueeze(2)
        return vert_normal

    def get_render(self, batch_size=1):
        R, T = look_at_view_transform(10, 0, 0)
        R = R.repeat(batch_size, 1, 1)
        T = torch.zeros((batch_size, 3), dtype=torch.float32).to(self.device)
        cameras = FoVPerspectiveCameras(
            device=self.device,
            R=R,
            T=T,
            znear=0.01,
            zfar=20,
            # Convert the focal length in pixels to a full field of view in degrees.
            fov=2 * np.arctan(self.img_w // 2 / self.focal) * 180.0 / np.pi,
        )
        lights = PointLights(
            device=self.device,
            location=[[0.0, 0.0, 1e5]],
            ambient_color=[[1, 1, 1]],
            specular_color=[[0.0, 0.0, 0.0]],
            diffuse_color=[[0.0, 0.0, 0.0]],
        )
        sigma = 1e-4
        raster_settings = RasterizationSettings(
            image_size=(self.img_h, self.img_w),
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma / 18.0,
            faces_per_pixel=2,
            perspective_correct=False,
        )
        blend_params = blending.BlendParams(background_color=[0, 0, 0])
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(raster_settings=raster_settings, cameras=cameras),
            shader=SoftSimpleShader(
                lights=lights, blend_params=blend_params, cameras=cameras
            ),
        )
        return renderer.to(self.device)

    @staticmethod
    def Illumination_layer(face_texture, norm, gamma):
        # Shade per-vertex albedo with second-order spherical harmonics (SH)
        # lighting: 9 coefficients per RGB channel, 27 per batch item.
        n_b, num_vertex, _ = face_texture.size()
        n_v_full = n_b * num_vertex
        gamma = gamma.view(-1, 3, 9).clone()
        gamma[:, :, 0] += 0.8  # bias the constant (ambient) SH coefficient
        gamma = gamma.permute(0, 2, 1)

        # SH normalization constants.
        a0 = np.pi
        a1 = 2 * np.pi / np.sqrt(3.0)
        a2 = 2 * np.pi / np.sqrt(8.0)
        c0 = 1 / np.sqrt(4 * np.pi)
        c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi)
        c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi)
        d0 = 0.5 / np.sqrt(3.0)

        Y0 = torch.ones(n_v_full).to(gamma.device).float() * a0 * c0
        norm = norm.view(-1, 3)
        nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2]

        # The 9 SH basis functions evaluated at every vertex normal.
        arrH = []
        arrH.append(Y0)
        arrH.append(-a1 * c1 * ny)
        arrH.append(a1 * c1 * nz)
        arrH.append(-a1 * c1 * nx)
        arrH.append(a2 * c2 * nx * ny)
        arrH.append(-a2 * c2 * ny * nz)
        arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1))
        arrH.append(-a2 * c2 * nx * nz)
        arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2)))

        H = torch.stack(arrH, 1)
        Y = H.view(n_b, num_vertex, 9)
        # Contract basis values with SH coefficients -> per-vertex RGB gain.
        lighting = Y.bmm(gamma)

        face_color = face_texture * lighting
        return face_color

    def forward(self, rott_geometry, texture, diffuse_sh):
        face_normal = self.compute_normal(rott_geometry)
        face_color = self.Illumination_layer(texture, face_normal, diffuse_sh)
        face_color = TexturesVertex(face_color)
        mesh = Meshes(
            rott_geometry,
            self.tris.float().repeat(rott_geometry.shape[0], 1, 1),
            face_color,
        )
        rendered_img = self.renderer(mesh)
        rendered_img = torch.clamp(rendered_img, 0, 255)
        return rendered_img
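

# A minimal usage sketch (an illustrative assumption, not part of the original
# module): renders one frame from random inputs as a smoke test. Shapes are
# inferred from the code above -- `rott_geometry` and `texture` are
# (batch, num_vertex, 3), and `diffuse_sh` carries 27 SH coefficients
# (9 per RGB channel) per batch item; num_vertex must match the topology
# stored in 3DMM/topology_info.npy.
if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    renderer = Render_3DMM(batch_size=1, device=device)

    # Vertex count implied by the largest index in the triangle topology.
    num_vertex = int(renderer.tris.max().item()) + 1
    geometry = torch.randn(1, num_vertex, 3, device=device)
    texture = torch.rand(1, num_vertex, 3, device=device) * 255.0  # albedo in [0, 255]
    diffuse_sh = torch.zeros(1, 27, device=device)  # ambient-only lighting

    rendered = renderer(geometry, texture, diffuse_sh)  # (1, img_h, img_w, 4) RGBA
    print(rendered.shape)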