liuyuan-pal committed · Commit 8e2f608 · 1 Parent(s): df916e6
update

Files changed:
- .gitignore (+1, -1)
- app.py (+117, -104)
.gitignore CHANGED
@@ -2,4 +2,4 @@
 training_examples
 objaverse_examples
 ldm/__pycache__/
-
+__pycache__/
app.py CHANGED
@@ -12,11 +12,6 @@ from ldm.util import add_margin, instantiate_from_config
 from sam_utils import sam_init, sam_out_nosave
 
 import torch
-print(f"Is CUDA available: {torch.cuda.is_available()}")
-# True
-print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
-# Tesla T4
-
 _TITLE = '''SyncDreamer: Generating Multiview-consistent Images from a Single-view Image'''
 _DESCRIPTION = '''
 <div>
@@ -26,18 +21,24 @@ _DESCRIPTION = '''
 </div>
 Given a single-view image, SyncDreamer is able to generate multiview-consistent images, which enables direct 3D reconstruction with NeuS or NeRF without SDS loss
 
-
-
-
-
+Procedure:
+**Step 0**. Upload an image or select an example. ==> The foreground is masked out by SAM.
+**Step 1**. Select "Crop size" and click "Crop it". ==> The foreground object is centered and resized.
+**Step 2**. Select "Elevation angle" and click "Run generation". ==> Generate multiview images. (This costs about 2 min.)
+To reconstruct a NeRF or a 3D mesh from the generated images, please refer to our [github repository](https://github.com/liuyuan-pal/SyncDreamer).
 '''
-_USER_GUIDE0 = "Step0: Please upload an image in the block above (or choose an example
-_USER_GUIDE1 = "Step1: Please select a
-_USER_GUIDE2 = "Step2: Please choose a
+_USER_GUIDE0 = "Step0: Please upload an image in the block above (or choose an example shown in the left)."
+_USER_GUIDE1 = "Step1: Please select a **Crop size** and click **Crop it**."
+_USER_GUIDE2 = "Step2: Please choose an **Elevation angle** and click **Run Generate**. This costs about 2 min."
 _USER_GUIDE3 = "Generated multiview images are shown below!"
 
 deployed = True
 
+if deployed:
+    print(f"Is CUDA available: {torch.cuda.is_available()}")
+    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
+
+
 class BackgroundRemoval:
     def __init__(self, device='cuda'):
         from carvekit.api.high import HiInterface
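Side note (not part of the commit): the guarded diagnostics above assume a GPU is present whenever `deployed` is true; `torch.cuda.get_device_name(torch.cuda.current_device())` raises when CUDA is unavailable. A defensive variant would be:

```python
import torch

# Query the device name only after the availability check succeeds;
# current_device()/get_device_name() raise when no CUDA device is visible.
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("CUDA not available; running on CPU")
```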
@@ -74,73 +75,74 @@ def resize_inputs(image_input, crop_size):
     return results
 
 def generate(model, batch_view_num, sample_num, cfg_scale, seed, image_input, elevation_input):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if deployed:
+        seed=int(seed)
+        torch.random.manual_seed(seed)
+        np.random.seed(seed)
+
+        # prepare data
+        image_input = np.asarray(image_input)
+        image_input = image_input.astype(np.float32) / 255.0
+        alpha_values = image_input[:,:, 3:]
+        image_input[:, :, :3] = alpha_values * image_input[:,:, :3] + 1 - alpha_values # white background
+        image_input = image_input[:, :, :3] * 2.0 - 1.0
+        image_input = torch.from_numpy(image_input.astype(np.float32))
+        elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
+        data = {"input_image": image_input, "input_elevation": elevation_input}
+        for k, v in data.items():
+            if deployed:
+                data[k] = v.unsqueeze(0).cuda()
+            else:
+                data[k] = v.unsqueeze(0)
+            data[k] = torch.repeat_interleave(data[k], sample_num, dim=0)
+
         if deployed:
-
+            x_sample = model.sample(data, cfg_scale, batch_view_num)
         else:
-
-
-
-
-            x_sample =
+            x_sample = torch.zeros(sample_num, 16, 3, 256, 256)
+
+        B, N, _, H, W = x_sample.shape
+        x_sample = (torch.clamp(x_sample,max=1.0,min=-1.0) + 1) * 0.5
+        x_sample = x_sample.permute(0,1,3,4,2).cpu().numpy() * 255
+        x_sample = x_sample.astype(np.uint8)
+
+        results = []
+        for bi in range(B):
+            results.append(np.concatenate([x_sample[bi,ni] for ni in range(N)], 1))
+        results = np.concatenate(results, 0)
+        return Image.fromarray(results)
     else:
-
-
-        B, N, _, H, W = x_sample.shape
-        x_sample = (torch.clamp(x_sample,max=1.0,min=-1.0) + 1) * 0.5
-        x_sample = x_sample.permute(0,1,3,4,2).cpu().numpy() * 255
-        x_sample = x_sample.astype(np.uint8)
-
-        results = []
-        for bi in range(B):
-            results.append(np.concatenate([x_sample[bi,ni] for ni in range(N)], 1))
-        results = np.concatenate(results, 0)
-        return Image.fromarray(results)
+        return Image.fromarray(np.zeros([sample_num*256,16*256,3],np.uint8))
 
-def white_background(img):
-    img = np.asarray(img,np.float32)/255
-    rgb = img[:,:,3:] * img[:,:,:3] + 1 - img[:,:,3:]
-    rgb = (rgb*255).astype(np.uint8)
-    return Image.fromarray(rgb)
 
 def sam_predict(predictor, removal, raw_im):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if deployed:
+        raw_im.thumbnail([512, 512], Image.Resampling.LANCZOS)
+        image_nobg = removal(raw_im.convert('RGB'))
+        arr = np.asarray(image_nobg)[:, :, -1]
+        x_nonzero = np.nonzero(arr.sum(axis=0))
+        y_nonzero = np.nonzero(arr.sum(axis=1))
+        x_min = int(x_nonzero[0].min())
+        y_min = int(y_nonzero[0].min())
+        x_max = int(x_nonzero[0].max())
+        y_max = int(y_nonzero[0].max())
+        # image_nobg.save('./nobg.png')
+
+        image_nobg.thumbnail([512, 512], Image.Resampling.LANCZOS)
+        image_sam = sam_out_nosave(predictor, image_nobg.convert("RGB"), (x_min, y_min, x_max, y_max))
+
+        # imsave('./mask.png', np.asarray(image_sam)[:,:,3]*255)
+        image_sam = np.asarray(image_sam, np.float32) / 255
+        out_mask = image_sam[:, :, 3:]
+        out_rgb = image_sam[:, :, :3] * out_mask + 1 - out_mask
+        out_img = (np.concatenate([out_rgb, out_mask], 2) * 255).astype(np.uint8)
+
+        image_sam = Image.fromarray(out_img, mode='RGBA')
+        # image_sam.save('./output.png')
+        torch.cuda.empty_cache()
+        return image_sam
+    else:
+        return raw_im
 
 def run_demo():
     # device = f"cuda:0" if torch.cuda.is_available() else "cpu"
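Side note (not part of the commit): the new generate() composites the RGBA input onto a white background, rescales it to [-1, 1], and batches it together with the elevation before calling model.sample. A minimal standalone sketch of just that preprocessing; the helper name prepare_input and its signature are illustrative, not taken from app.py:

```python
import numpy as np
import torch

def prepare_input(image_rgba: np.ndarray, elevation_deg: float, sample_num: int):
    """Mirror of the data preparation in generate(): white-background composite,
    [-1, 1] scaling, then batching/repeating once per requested sample."""
    img = image_rgba.astype(np.float32) / 255.0                    # HxWx4 in [0, 1]
    alpha = img[:, :, 3:]                                          # alpha channel
    rgb = alpha * img[:, :, :3] + 1 - alpha                        # white where alpha == 0
    rgb = rgb * 2.0 - 1.0                                          # [0, 1] -> [-1, 1]
    data = {
        "input_image": torch.from_numpy(rgb.astype(np.float32)),
        "input_elevation": torch.from_numpy(np.asarray([np.deg2rad(elevation_deg)], np.float32)),
    }
    # add a batch dimension and repeat so every sample sees the same conditioning
    return {k: torch.repeat_interleave(v.unsqueeze(0), sample_num, dim=0) for k, v in data.items()}
```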
@@ -156,21 +158,28 @@ def run_demo():
         model.load_state_dict(ckpt['state_dict'], strict=True)
         model = model.cuda().eval()
         del ckpt
+        mask_predictor = sam_init()
+        removal = BackgroundRemoval()
     else:
         model = None
-
-
-    mask_predictor = sam_init()
-    removal = BackgroundRemoval()
-
-    # with open('instructions_12345.md', 'r') as f:
-    #     article = f.read()
+        mask_predictor = None
+        removal = None
 
     # NOTE: Examples must match inputs
-
-
-
-
+    examples_full = [
+        ['hf_demo/examples/basket.png',30,200],
+        ['hf_demo/examples/cat.png',30,200],
+        ['hf_demo/examples/crab.png',30,200],
+        ['hf_demo/examples/elephant.png',30,200],
+        ['hf_demo/examples/flower.png',0,200],
+        ['hf_demo/examples/forest.png',30,200],
+        ['hf_demo/examples/monkey.png',30,200],
+        ['hf_demo/examples/teapot.png',0,200],
+    ]
+
+    image_block = gr.Image(type='pil', image_mode='RGBA', height=256, label='Input image', tool=None, interactive=True)
+    elevation = gr.Slider(-10, 40, 30, step=5, label='Elevation angle', interactive=True)
+    crop_size = gr.Slider(120, 240, 200, step=10, label='Crop size', interactive=True)
 
     # Compose demo layout & data flow.
     with gr.Blocks(title=_TITLE, css="hf_demo/style.css") as demo:
@@ -182,34 +191,38 @@ def run_demo():
         gr.Markdown(_DESCRIPTION)
 
         with gr.Row(variant='panel'):
-            with gr.Column(scale=1):
-                image_block = gr.Image(type='pil', image_mode='RGBA', height=256, label='Input image', tool=None, interactive=True)
-                guide_text = gr.Markdown(_USER_GUIDE0, visible=True)
+            with gr.Column(scale=1.2):
                 gr.Examples(
                     examples=examples_full,  # NOTE: elements must match inputs list!
-                    inputs=[image_block],
-                    outputs=[image_block],
+                    inputs=[image_block, elevation, crop_size],
+                    outputs=[image_block, elevation, crop_size],
                     cache_examples=False,
                     label='Examples (click one of the images below to start)',
-                    examples_per_page=
+                    examples_per_page=5,
                 )
 
+            with gr.Column(scale=0.8):
+                image_block.render()
+                guide_text = gr.Markdown(_USER_GUIDE0, visible=True)
+                fig0 = gr.Image(value=Image.open('assets/crop_size.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False)
 
-
+
+            with gr.Column(scale=0.8):
                 sam_block = gr.Image(type='pil', image_mode='RGBA', label="SAM output", height=256, interactive=False)
-
-                crop_btn = gr.Button('Crop
-
+                crop_size.render()
+                crop_btn = gr.Button('Crop it', variant='primary', interactive=True)
+                fig1 = gr.Image(value=Image.open('assets/elevation.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False)
 
-            with gr.Column(scale=
+            with gr.Column(scale=0.8):
                 input_block = gr.Image(type='pil', image_mode='RGBA', label="Input to SyncDreamer", height=256, interactive=False)
-                elevation
-
-
-
-
-
-
+                elevation.render()
+                with gr.Accordion('Advanced options', open=False):
+                    cfg_scale = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
+                    sample_num = gr.Slider(1, 2, 1, step=1, label='Sample num', interactive=True, info='How many instance (16 images per instance)')
+                    batch_view_num = gr.Slider(1, 16, 16, step=1, label='Batch num', interactive=True)
+                    seed = gr.Number(6033, label='Random seed', interactive=True)
+                run_btn = gr.Button('Run generation', variant='primary', interactive=True)
+
 
         output_block = gr.Image(type='pil', image_mode='RGB', label="Outputs of SyncDreamer", height=256, interactive=False)
 
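Side note (not part of the commit): in the layout above, image_block, elevation, and crop_size are created before gr.Blocks() (previous hunk) and only placed into their columns with .render(); that is what lets gr.Examples in the first column bind each examples_full row to [image_block, elevation, crop_size] even though those components are rendered in other columns. A stripped-down sketch of the pattern (the example path assumes the repo's hf_demo/examples assets):

```python
import gradio as gr

# Components are instantiated up front, still unrendered...
image_block = gr.Image(type='pil', image_mode='RGBA', height=256, label='Input image', interactive=True)
elevation = gr.Slider(-10, 40, 30, step=5, label='Elevation angle', interactive=True)
crop_size = gr.Slider(120, 240, 200, step=10, label='Crop size', interactive=True)

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # ...so the examples table can already reference them here;
            # each row is [image path, elevation in degrees, crop size in pixels].
            gr.Examples(
                examples=[['hf_demo/examples/basket.png', 30, 200]],
                inputs=[image_block, elevation, crop_size],
                cache_examples=False,
            )
        with gr.Column():
            image_block.render()   # the pre-built component is placed here
        with gr.Column():
            elevation.render()
            crop_size.render()

# demo.launch()  # uncomment to run the sketch locally
```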
@@ -217,9 +230,9 @@ def run_demo():
         image_block.change(fn=partial(sam_predict, mask_predictor, removal), inputs=[image_block], outputs=[sam_block], queue=False)\
             .success(fn=partial(update_guide, _USER_GUIDE1), outputs=[guide_text], queue=False)
 
-
+        crop_size.change(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=False)\
             .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
-        crop_btn.click(fn=resize_inputs, inputs=[sam_block,
+        crop_btn.click(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=False)\
             .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
 
         run_btn.click(partial(generate, model), inputs=[batch_view_num, sample_num, cfg_scale, seed, input_block, elevation], outputs=[output_block], queue=False)\
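Side note (not part of the commit): the handlers above chain each action with .success() so the guide text only advances after the preceding step completes. update_guide itself is not shown in this diff; from its use with partial(update_guide, _USER_GUIDE1) it presumably just pushes the next instruction into the guide Markdown. A minimal sketch of the same chaining pattern with placeholder callbacks (all names below are illustrative, not the app's):

```python
import gradio as gr
from functools import partial

def update_guide(message):
    # Assumed behaviour: replace the guide Markdown with the next instruction.
    return gr.update(value=message)

def fake_crop(image, size):
    # Placeholder for resize_inputs(); passes the image through unchanged.
    return image

with gr.Blocks() as demo:
    guide = gr.Markdown("Step 1: upload an image")
    image_in = gr.Image(type='pil')
    crop_size = gr.Slider(120, 240, 200, step=10, label='Crop size')
    cropped = gr.Image(type='pil', interactive=False)

    # Run the work first, then advance the guide only if it succeeded.
    crop_size.change(fn=fake_crop, inputs=[image_in, crop_size], outputs=[cropped], queue=False)\
        .success(fn=partial(update_guide, "Step 2: run generation"), outputs=[guide], queue=False)

# demo.launch()  # uncomment to try the flow locally
```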