File size: 4,654 Bytes
b611dce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import gradio as gr
import os
hf_token = os.environ.get("HF_TOKEN")
import spaces
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler, AutoencoderKL
import torch
import time

class Dummy:
    """Empty attribute container.

    Defines no attributes or methods of its own; instances exist only so
    that arbitrary attributes can be attached to them after construction.
    """

# Selectable output sizes. Each entry is a "<width> <height>" string: it is
# offered verbatim in the Resolution dropdown and later split on whitespace
# and converted to two ints inside infer().
resolutions = ["1024 1024","1280 768","1344 768","768 1344","768 1280" ] 

# Load pipeline
#
# Everything below runs once at import time and the order matters:
#   1. load the fp16 VAE and hand it to the base BRIA 2.3 pipeline,
#   2. load the FAST LoRA weights and call fuse_lora() on the pipeline,
#   3. swap the scheduler for an LCMScheduler built from the existing
#      scheduler's config,
#   4. move the pipeline to the CUDA device.

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained("briaai/BRIA-2.3", torch_dtype=torch.float16, vae=vae)
pipe.load_lora_weights("briaai/BRIA-2.3-FAST-LORA")
pipe.fuse_lora()
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.to('cuda')
# Drop the module-level name only; the VAE was passed to the pipeline above,
# which presumably keeps its own reference to it.
del vae


# NOTE(review): overrides the pipeline's force_zeros_for_empty_prompt
# attribute; the exact effect on empty-prompt handling is defined by
# diffusers, not by this file -- confirm against the pipeline docs.
pipe.force_zeros_for_empty_prompt = False

# print("Optimizing BRIA 2.3 FAST LORA - this could take a while")
# t=time.time()
# pipe.unet = torch.compile(
#     pipe.unet, mode="reduce-overhead", fullgraph=True # 600 secs compilation
# )
# with torch.no_grad():
#     outputs = pipe(
#         prompt="an apple",
#         num_inference_steps=8,
#     )

#     # This will avoid future compilations on different shapes
#     unet_compiled = torch._dynamo.run(pipe.unet)
#     unet_compiled.config=pipe.unet.config
#     unet_compiled.add_embedding = Dummy()
#     unet_compiled.add_embedding.linear_1 = Dummy()
#     unet_compiled.add_embedding.linear_1.in_features = pipe.unet.add_embedding.linear_1.in_features
#     pipe.unet = unet_compiled

# print(f"Optimizing finished successfully after {time.time()-t} secs")

def _seed_to_generator(seed):
    """Turn the user-supplied seed into a CUDA generator, or ``None``.

    ``None`` is returned, leaving the pipeline without an explicit generator,
    when:

    * the seed parses to ``-1`` in any form (``"-1"``, ``" -1 "``, the
      integer ``-1``), which is the UI's "no fixed seed" sentinel;
    * the seed cannot be converted to an integer (empty string, text,
      ``None``); or
    * the generator cannot be created or seeded with that value.

    Invalid seeds are reported on stdout instead of being silently ignored.
    """
    try:
        seed_value = int(seed)
    except (TypeError, ValueError, OverflowError):
        print(f"ignoring invalid seed {seed!r}")
        return None

    # Compare the parsed number rather than the raw text so that every
    # spelling of -1 is treated as the sentinel.
    if seed_value == -1:
        return None

    try:
        return torch.Generator("cuda").manual_seed(seed_value)
    except (RuntimeError, OverflowError, ValueError) as err:
        # Best effort, as before: fall back to no generator rather than
        # failing the whole request.
        print(f"ignoring unusable seed {seed!r}: {err}")
        return None


@spaces.GPU(enable_queue=True)
def infer(prompt,seed,resolution):
    """Generate one image for *prompt* with the module-level ``pipe``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed as entered in the UI. ``-1`` or anything that is not a
            usable integer means "no fixed seed"; see ``_seed_to_generator``.
        resolution: ``"<width> <height>"`` string, e.g. ``"1024 1024"``.

    Returns:
        The first entry of the pipeline output's ``images``.

    Raises:
        ValueError: If *resolution* does not consist of exactly two
            whitespace-separated integers.
    """
    print(f"""
    —/n
    {prompt}
    """)

    started = time.time()

    generator = _seed_to_generator(seed)

    w, h = (int(part) for part in resolution.split())
    image = pipe(
        prompt,
        num_inference_steps=8,
        generator=generator,
        width=w,
        height=h,
        guidance_scale=0,
    ).images[0]
    print(f'gen time is {time.time()-started} secs')

    # Future
    # Add amount of steps
    # if nsfw:
    #     raise gr.Error("Generated image is NSFW")

    return image

# Centers the main column and caps its width.
css = """
#col-container{
    margin: 0 auto;
    max-width: 580px;
}
"""
# Build the UI: intro text, the three inputs (prompt, resolution, seed),
# a Generate button and the result image.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## BRIA 2.3 FAST LORA")
        # The Startup Plan link's href is now properly closed; it previously
        # ended in a stray ")" with no closing quote, which absorbed the
        # target attribute into the URL and broke the link.
        gr.HTML('''
          <p style="margin-bottom: 10px; font-size: 94%">
            This is a demo for 
            <a href="https://huggingface.co/briaai/BRIA-2.3-FAST-LORA" target="_blank">BRIA 2.3 FAST LORA </a>. 
            This is a fast version of BRIA 2.3 text-to-image model, still trained on licensed data, and so provides full legal liability coverage for copyright and privacy infringement.
            You can also try it for free in our <a href="https://labs.bria.ai/" target="_blank">webapp demo </a>.
            Are you a startup or a student? We encourage you to apply for our 
            <a href="https://pages.bria.ai/the-visual-generative-ai-platform-for-builders-startups-plan?_gl=1*cqrl81*_ga*MTIxMDI2NzI5OC4xNjk5NTQ3MDAz*_ga_WRN60H46X4*MTcwOTM5OTMzNC4yNzguMC4xNzA5Mzk5MzM0LjYwLjAuMA.." target="_blank">Startup Plan  </a>
            This program are designed to support emerging businesses and academic pursuits with our cutting-edge technology.
          </p>
        ''')
        with gr.Group():
            with gr.Column():
                prompt_in = gr.Textbox(label="Prompt", value="A smiling man with wavy brown hair and a trimmed beard")
                resolution = gr.Dropdown(value=resolutions[0], show_label=True, label="Resolution", choices=resolutions)
                # Textbox values are strings; "-1" is the "no fixed seed"
                # sentinel that infer() checks for.
                seed = gr.Textbox(label="Seed", value="-1")
                submit_btn = gr.Button("Generate")
        result = gr.Image(label="BRIA 2.3 FAST LORA Result")

        # gr.Examples(
        #     examples = [ 
        #         "Dragon, digital art, by Greg Rutkowski",
        #         "Armored knight holding sword",
        #         "A flat roof villa near a river with black walls and huge windows",
        #         "A calm and peaceful office",
        #         "Pirate guinea pig"
        #     ],
        #     fn = infer, 
        #     inputs = [
        #         prompt_in
        #     ],
        #     outputs = [
        #         result
        #     ]
        # )

    # Input order must match infer(prompt, seed, resolution).
    submit_btn.click(
        fn = infer,
        inputs = [
            prompt_in,
            seed,
            resolution
        ],
        outputs = [
            result
        ]
    )

demo.queue().launch(show_api=False)