File size: 8,895 Bytes
fd5e0f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
'''
Scene generation pipeline:
render frames from the Gaussian scene (GS),
then inpaint them with Fooocus.
'''
import os
import torch
import numpy as np
from PIL import Image
from copy import deepcopy
from ops.utils import *
from ops.sky import Sky_Seg_Tool
from ops.visual_check import Check
from ops.gs.train import GS_Train_Tool
from pipe.lvm_inpaint import Inpaint_Tool
from pipe.reconstruct import Reconstruct_Tool
from ops.trajs import _generate_trajectory
from ops.connect import Occlusion_Removal
from ops.gs.basic import Frame,Gaussian_Scene

from ops.mcs import HackSD_MCS
from pipe.refine_mvdps import Refinement_Tool_MCS
        
class Pipeline():
    def __init__(self,cfg) -> None:
        '''
        Keep the configuration and build the tools shared by the pipeline stages.
        cfg: configuration object; cfg.model.sky.value is read here and cfg is
             also passed to the sky-segmentation, inpainting and reconstruction tools.
        '''
        self.cfg = cfg
        self.device = 'cuda'
        # value written into the depth of pixels segmented as sky
        sky_cfg = cfg.model.sky
        self.sky_value = sky_cfg.value
        # tools that are configured from cfg
        self.sky_segor = Sky_Seg_Tool(cfg)
        self.rgb_inpaintor = Inpaint_Tool(cfg)
        self.reconstructor = Reconstruct_Tool(cfg)
        # tools built without arguments (flagged as temporary by the original author)
        self.removalor = Occlusion_Removal()
        self.checkor = Check()

    def _mkdir(self,dir):
        if not os.path.exists(dir):
            os.makedirs(dir)

    def _resize_input(self,fn):
        '''
        Resize the image file `fn` in place so that its long edge equals
        cfg.scene.input.resize_long_edge while keeping the aspect ratio.
        Before `fn` is overwritten, the loaded image is saved next to it as
        '<name>.original<ext>'.
        fn: path of the input image.
        NOTE(review): calling this twice on the same file replaces the backup
        with the already resized image.
        '''
        resize_long_edge = int(self.cfg.scene.input.resize_long_edge)
        print(f'[Preprocess...] Resize the long edge of input image to {resize_long_edge}.')
        root,ext = os.path.splitext(fn)
        backup_fn = root + '.original' + ext
        # the context manager releases the file handle before fn is overwritten
        with Image.open(fn) as src:
            src.save(backup_fn) # back up original image
            # convert('RGB') drops an alpha channel and expands grayscale or
            # palette images, so the array always has three color channels
            rgb = np.array(src.convert('RGB'))/255.
        H,W = rgb.shape[0:2]
        # scale the short edge with the same factor as the long edge,
        # but never let it collapse to zero pixels
        if H>W:
            W = max(1,int(W*resize_long_edge/H))
            H = resize_long_edge
        else:
            H = max(1,int(H*resize_long_edge/W))
            W = resize_long_edge
        rgb = cv2.resize(rgb,(W,H))
        pic = (rgb * 255.0).clip(0, 255)
        Image.fromarray(pic.astype(np.uint8)).save(fn)

    def _initialization(self,rgb):
        '''
        Create the first two frames of the scene from the input image and
        run an initial optimization of the scene.
        Frame 1 is the input view itself, frame 2 is the outpainted view
        returned by the inpainting tool. Both get depth, sky mask, inpaint
        masks and target depth, are added as trainable frames, and the scene
        is then trained for 100 iterations.
        rgb: input image; only its first three channels are used.
        Reads self.outpaint_selections and self.outpaint_extend_times and
        replaces self.scene with the trained scene.
        '''
        rgb = np.array(rgb)[:,:,:3]
        # conduct outpainting on rgb and change cu,cv
        outpaint_frame :Frame = self.rgb_inpaintor(Frame(rgb=rgb),
                                                   outpaint_selections=self.outpaint_selections,
                                                   outpaint_extend_times=self.outpaint_extend_times)
        # conduct reconstruction on outpaint results
        # NOTE(review): the intrinsic returned by this first call is overwritten
        # by the next line before it is used - confirm whether the input-view
        # estimate was meant to be kept.
        _,intrinsic,_ = self.reconstructor._ProDpt_(rgb) # estimate focal on input view
        metric_dpt,intrinsic,edge_msk = self.reconstructor._ProDpt_(outpaint_frame.rgb)
        outpaint_frame.intrinsic = deepcopy(intrinsic)
        # split to input and outpaint areas
        input_frame = Frame(H=rgb.shape[0],
                            W=rgb.shape[1],
                            rgb=rgb,
                            intrinsic=deepcopy(intrinsic),
                            extrinsic=np.eye(4))
        # the input frame uses its own image center in the last column of the intrinsic
        input_frame.intrinsic[0,-1] = input_frame.W/2.
        input_frame.intrinsic[1,-1] = input_frame.H/2.
        # pixels of the outpainted view that are not marked for inpainting;
        # presumably exactly the input image region, since the selections
        # below are reshaped to the input size - TODO confirm
        input_area = ~outpaint_frame.inpaint
        input_edg = edge_msk[input_area].reshape(input_frame.H,input_frame.W)
        input_dpt = metric_dpt[input_area].reshape(input_frame.H,input_frame.W)
        sky = self.sky_segor(input_frame.rgb)
        input_frame.sky = sky
        # sky pixels get the configured sky value as depth
        input_dpt[sky] = self.sky_value
        input_frame.dpt = input_dpt
        # for the input frame every non-sky pixel is part of the inpaint mask
        input_frame.inpaint = np.ones_like(input_edg,bool) & (~sky)
        input_frame.inpaint_wo_edge = (~input_edg) & (~sky)
        input_frame.ideal_dpt = deepcopy(input_dpt)
        input_frame.prompt = outpaint_frame.prompt
        # outpaint frame: same treatment on the full outpainted view
        sky = self.sky_segor(outpaint_frame.rgb)
        outpaint_frame.sky = sky
        # metric_dpt is modified in place here, after the input depth was extracted above
        metric_dpt[sky] = self.sky_value
        outpaint_frame.dpt = metric_dpt
        outpaint_frame.ideal_dpt = deepcopy(metric_dpt)
        outpaint_frame.inpaint = (outpaint_frame.inpaint)&(~sky)
        # uses the inpaint mask from the line above, i.e. with sky already removed
        outpaint_frame.inpaint_wo_edge = (outpaint_frame.inpaint)&(~edge_msk)
        # add init frame
        self.scene._add_trainable_frame(input_frame,require_grad=True)
        self.scene._add_trainable_frame(outpaint_frame,require_grad=True)
        # train on the two initial frames and keep the returned scene
        self.scene = GS_Train_Tool(self.scene,iters=100)(self.scene.frames)
    
    def _generate_traj(self):
        '''
        Generate the trajectory for the current scene from cfg and keep it in
        self.dense_trajs (iterated pose by pose when selecting the next frame).
        '''
        trajectory = _generate_trajectory(self.cfg,self.scene)
        self.dense_trajs = trajectory
        
    def _pose_to_frame(self,extrinsic,margin=32):
        '''
        Build a frame for the pose `extrinsic` and render it from the scene for inpainting.
        The frame is `margin` pixels higher and wider than the first scene frame.
        It uses a copy of that frame's intrinsic whose last-column entries in
        rows 0 and 1 are set to the center of the enlarged frame, and it takes
        the prompt of the last scene frame.
        Returns the frame produced by scene._render_for_inpaint.
        '''
        reference = self.scene.frames[0]
        height = reference.H + margin
        width = reference.W + margin
        # work on a copy so the reference frame keeps its own intrinsic
        cam = deepcopy(reference.intrinsic)
        cam[0,-1] = width/2
        cam[1,-1] = height/2
        new_frame = Frame(H=height,W=width,intrinsic=cam,extrinsic=extrinsic,
                          prompt=self.scene.frames[-1].prompt)
        return self.scene._render_for_inpaint(new_frame)
      
    def _next_frame(self,margin=32):
        # select the frame with largest holes but less than 60% 
        inpaint_area_ratio = []
        for pose in self.dense_trajs:
            temp_frame = self._pose_to_frame(pose,margin)
            inpaint_mask = temp_frame.inpaint 
            inpaint_area_ratio.append(np.mean(inpaint_mask))
        inpaint_area_ratio = np.array(inpaint_area_ratio)
        inpaint_area_ratio[inpaint_area_ratio > 0.6] = 0.
        # remove adjustancy frames
        for s in self.select_frames:
            inpaint_area_ratio[s] = 0.
            if s-1>-1:
                inpaint_area_ratio[s-1] = 0.
            if s+1<len(self.dense_trajs):
                inpaint_area_ratio[s+1] = 0.
        # select the largest ones
        select = np.argmax(inpaint_area_ratio)
        if inpaint_area_ratio[select] < 0.0001: return None
        self.select_frames.append(select)
        pose = self.dense_trajs[select]
        frame = self._pose_to_frame(pose,margin)
        return frame   

    def _inpaint_next_frame(self,margin=32):
        '''
        Select the next frame, inpaint its color and depth, and add it to the scene.
        margin: forwarded to _next_frame.
        Returns None when _next_frame yields no frame, otherwise 0.
        '''
        frame = self._next_frame(margin)
        if frame is None:
            return None
        # inpaint rgb
        frame = self.rgb_inpaintor(frame)
        # inpaint dpt: the reconstructor gets the current depth together with
        # the mask of pixels that are not marked for inpainting
        keep_area = ~frame.inpaint
        dpt_outputs = self.reconstructor._Guide_ProDpt_(frame.rgb,frame.intrinsic,frame.dpt,keep_area)
        aligned_dpt,target_dpt,_,edges = dpt_outputs
        frame.dpt = aligned_dpt
        frame = self.removalor(self.scene,frame)
        # sky pixels get the sky value as depth and leave the inpaint mask
        sky_msk = self.sky_segor(frame.rgb)
        frame.sky = sky_msk
        frame.dpt[sky_msk] = self.sky_value
        frame.inpaint = frame.inpaint & ~sky_msk
        frame.inpaint_wo_edge = frame.inpaint & ~edges
        # determine target depth and normal
        frame.ideal_dpt = target_dpt
        self.scene._add_trainable_frame(frame)
        return 0

    def _coarse_scene(self,rgb):
        '''
        Build the coarse scene from the input image `rgb`.
        After the initialization and trajectory generation, up to
        self.n_sample-2 further frames are inpainted; the scene is trained
        for self.opt_iters_per_frame iterations after each of them. The loop
        stops early as soon as no further frame is selected.
        '''
        self._initialization(rgb)
        self._generate_traj()
        self.select_frames = []
        n_extra = self.n_sample-2
        for i in range(n_extra):
            print(f'Procecssing {i+2}/{self.n_sample} frame...')
            if self._inpaint_next_frame() is None:
                break
            trainer = GS_Train_Tool(self.scene,iters=self.opt_iters_per_frame)
            self.scene = trainer(self.scene.frames)

    def _MCS_Refinement(self):
        '''
        Refine self.scene with the MCS refinement tool.
        Builds a HackSD_MCS refiner from the checkpoints in cfg.model.optimize,
        wraps it in Refinement_Tool_MCS (kept as self.MVDPS), replaces
        self.scene with the tool's result and finally moves the refiner to the CPU.
        Both objects are created on self.device instead of a second
        hard-coded device string.
        '''
        refiner = HackSD_MCS(device=self.device,use_lcm=True,denoise_steps=self.mcs_iterations,
                             sd_ckpt=self.cfg.model.optimize.sd,
                             lcm_ckpt=self.cfg.model.optimize.lcm)
        self.MVDPS = Refinement_Tool_MCS(self.scene,device=self.device,
                                         refiner=refiner,
                                         traj_type=self.traj_type,
                                         n_view=self.mcs_n_view,
                                         rect_w=self.mcs_rect_w,
                                         n_gsopt_iters=self.mcs_gsopt_per_frame)
        self.scene = self.MVDPS()
        # the refiner is no longer needed on the device once refinement is done
        refiner.to('cpu')

    def __call__(self):
        '''
        Run the whole pipeline on the image at cfg.scene.input.rgb.
        Steps: copy the needed settings from cfg onto the instance, resize the
        input image in place, build the coarse scene, refine it, then save the
        scene as 'scene.pth' and call the checker's _render_video. Both outputs
        go to the directory of the input image ('.' when the path has no
        directory part).
        '''
        rgb_fn = self.cfg.scene.input.rgb
        # coarse
        self.scene = Gaussian_Scene(self.cfg)
        # for trajectory generation
        self.n_sample = self.cfg.scene.traj.n_sample
        self.traj_type = self.cfg.scene.traj.traj_type
        self.scene.traj_type = self.cfg.scene.traj.traj_type
        # for scene generation
        self.opt_iters_per_frame = self.cfg.scene.gaussian.opt_iters_per_frame
        self.outpaint_selections = self.cfg.scene.outpaint.outpaint_selections
        self.outpaint_extend_times = self.cfg.scene.outpaint.outpaint_extend_times
        # for scene refinement
        self.mcs_n_view = self.cfg.scene.mcs.n_view
        self.mcs_rect_w = self.cfg.scene.mcs.rect_w
        self.mcs_iterations = self.cfg.scene.mcs.steps
        self.mcs_gsopt_per_frame = self.cfg.scene.mcs.gsopt_iters
        # coarse scene
        self._resize_input(rgb_fn)
        # slicing at rfind('/') would cut off the last character of a bare
        # file name (rfind returns -1), so let os.path split the path
        out_dir = os.path.dirname(rgb_fn) or '.'
        with Image.open(rgb_fn) as rgb:
            self._coarse_scene(rgb)
        torch.cuda.empty_cache()
        # refinement
        self._MCS_Refinement()
        torch.save(self.scene,os.path.join(out_dir,'scene.pth'))
        # joining with '' keeps the trailing separator on the directory
        self.checkor._render_video(self.scene,save_dir=os.path.join(out_dir,''))