File size: 4,293 Bytes
71d12ce
 
caef638
71d12ce
 
cf3fc03
71d12ce
 
 
 
 
6b3c1e9
c49ce5c
71d12ce
269cf5b
c49ce5c
71d12ce
 
 
b796e0c
 
942501f
71d12ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b03eeaf
71d12ce
 
 
 
 
cf75aba
71d12ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448bb9c
71d12ce
e245f8d
71d12ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2acae5e
71d12ce
 
 
 
 
2acae5e
 
71d12ce
 
 
 
 
 
2acae5e
71d12ce
 
f107a56
1cd1544
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from diffusers import StableDiffusionLDM3DPipeline
import gradio as gr
import torch 
from PIL import Image
import base64
from io import BytesIO
from tempfile import NamedTemporaryFile
from pathlib import Path

# Directory for the generated PNGs; predict() writes its temporary files here.
Path("tmp").mkdir(exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device is {device}")
# Half precision only when running on the GPU; float32 otherwise.
torch_type = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionLDM3DPipeline.from_pretrained(
    "Intel/ldm3d-pano",
    torch_dtype=torch_type,
    # safety_checker=None,  # uncomment to disable the safety checker
)
if device == "cuda":
    # xformers is an optional speed-up: diffusers raises ModuleNotFoundError
    # (a subclass of ImportError) when it is not installed, so fall back to the
    # default attention implementation instead of failing at startup.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except ImportError as err:
        print(f"xformers unavailable, using default attention: {err}")
    # Model CPU offload moves each sub-model onto the GPU only while it runs.
    # The pipeline is therefore deliberately NOT moved with pipe.to("cuda")
    # first: that would copy every weight to the GPU just for the offload
    # setup to move them back to the CPU.
    pipe.enable_model_cpu_offload()
else:
    pipe.to(device)


def get_iframe(rgb_path: str, depth_path: str, viewer_mode: str = "6DOF") -> str:
    """Build the ``<iframe>`` markup that embeds the 3D viewer for an RGB/depth pair.

    The viewer page is loaded from ``static/`` through Gradio's ``file=`` route.
    The image locations are passed to it as ``/file=<path>`` URLs in the
    ``data-rgb`` / ``data-depth`` attributes.

    Args:
        rgb_path: Path of the RGB image file.
        depth_path: Path of the depth image file.
        viewer_mode: ``"6DOF"`` selects ``three6dof.html``; any other value
            falls back to ``depthmap.html``.

    Returns:
        The iframe HTML as a string.
    """
    # Imported locally: this module later binds the global name ``html`` to a
    # Gradio component, which would shadow a module-level ``import html``.
    from html import escape

    page = "three6dof.html" if viewer_mode == "6DOF" else "depthmap.html"
    # Escape the URLs so a quote or angle bracket in a path cannot break out
    # of the attribute value.
    rgb_url = escape(f"/file={rgb_path}", quote=True)
    depth_url = escape(f"/file={depth_path}", quote=True)
    return (
        f'<iframe src="file=static/{page}" width="100%" height="500px" '
        f'data-rgb="{rgb_url}" data-depth="{depth_url}"></iframe>'
    )


def _seeded_generator(seed: int, randomize_seed: bool) -> "tuple[torch.Generator, int]":
    """Return a fresh generator plus the seed it was initialised with.

    A dedicated generator is used so the global torch RNG is left untouched.
    The returned seed always reproduces the run when fed back with
    ``randomize_seed=False``.
    """
    generator = torch.Generator()
    if randomize_seed:
        # Generator.seed() draws a non-deterministic seed and applies it.
        # A bare torch.Generator() always starts from the same default seed,
        # so skipping this would make every "random" image identical.
        generator.seed()
    else:
        # Gradio sliders may deliver floats, and float rounding can push the
        # top of the slider just past 2**64 - 1, which manual_seed rejects.
        generator.manual_seed(min(int(seed), 2**64 - 1))
    return generator, generator.initial_seed()


def _save_png_to_tmp(image) -> str:
    """Save ``image`` as a new PNG file under ``tmp/`` and return its path.

    ``image`` is assumed to be a PIL image (the pipeline's default output
    type); only its ``save(fp, format=...)`` method is used.
    """
    # delete=False keeps the file after the handle closes so it can be served.
    # Writing through the open handle (instead of reopening it by name) also
    # works on Windows, where an open NamedTemporaryFile cannot be reopened.
    with NamedTemporaryFile(suffix=".png", delete=False, dir="tmp") as handle:
        image.save(handle, format="PNG")
    return handle.name


def predict(
    prompt: str,
    negative_prompt: str,
    guidance_scale: float = 5.0,
    seed: int = 0,
    randomize_seed: bool = True,
):
    """Generate a 1024x512 RGB panorama and its depth map for ``prompt``.

    Args:
        prompt: Text prompt for the LDM3D pipeline.
        negative_prompt: Text the generation should steer away from.
        guidance_scale: Classifier-free guidance strength.
        seed: Seed to use when ``randomize_seed`` is False.
        randomize_seed: Draw a fresh seed instead of using ``seed``.

    Returns:
        ``(rgb_path, depth_path, used_seed, iframe_html)``: paths of the PNGs
        written under ``tmp/``, the seed that reproduces this result, and the
        viewer iframe markup. This function does not delete the PNG files.
    """
    generator, used_seed = _seeded_generator(seed, randomize_seed)
    output = pipe(
        prompt,
        width=1024,
        height=512,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        generator=generator,
        num_inference_steps=50,
    )  # type: ignore
    rgb_path = _save_png_to_tmp(output.rgb[0])  # type: ignore
    depth_path = _save_png_to_tmp(output.depth[0])  # type: ignore

    iframe = get_iframe(rgb_path, depth_path)
    return rgb_path, depth_path, used_seed, iframe


# Gradio UI: generation controls on the left; the 3D viewer and the raw
# RGB / depth outputs on the right.
with gr.Blocks() as block:
    gr.Markdown(
        """
## LDM3d Demo 

[Model card](https://huggingface.co/Intel/ldm3d-pano)
[Diffusers docs](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/ldm3d_diffusion)
For better results, specify "360 view of" or "panoramic view of" in the prompt

"""
    )
    with gr.Row():
        # Left column: inputs for predict() and the seed readout.
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt")
            negative_prompt = gr.Textbox(label="Negative Prompt")
            guidance_scale = gr.Slider(
                minimum=0,
                maximum=10,
                step=0.1,
                value=5.0,
                label="Guidance Scale",
            )
            randomize_seed = gr.Checkbox(value=True, label="Randomize Seed")
            seed = gr.Slider(minimum=0, maximum=2**64 - 1, step=1, label="Seed")
            generated_seed = gr.Number(label="Generated Seed")
            markdown = gr.Markdown(label="Output Box")
            with gr.Row():
                new_btn = gr.Button("New Image")
        # Right column: iframe viewer above the two image outputs.
        with gr.Column(scale=2):
            html = gr.HTML(height='50%')
            with gr.Row():
                rgb = gr.Image(type="filepath", label="RGB Image")
                depth = gr.Image(type="filepath", label="Depth Image")

    # The examples and the button share predict()'s positional inputs and
    # route its 4-tuple result to the same output components.
    predict_inputs = (prompt, negative_prompt, guidance_scale, seed, randomize_seed)
    predict_outputs = (rgb, depth, generated_seed, html)

    gr.Examples(
        examples=[["360 view of a large bedroom", "", 7.0, 42, False]],
        inputs=list(predict_inputs),
        outputs=list(predict_outputs),
        fn=predict,
        cache_examples=True,
    )

    new_btn.click(
        fn=predict,
        inputs=list(predict_inputs),
        outputs=list(predict_outputs),
    )

block.launch()