Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files- app.py +16 -6
- lcm_txt2img/pipeline.py +1 -1
app.py
CHANGED
@@ -37,7 +37,7 @@ if TORCH_COMPILE:
|
|
37 |
|
38 |
def predict(prompt1, prompt2, merge_ratio, guidance, steps, sharpness, seed=1231231):
|
39 |
torch.manual_seed(seed)
|
40 |
-
|
41 |
prompt1=prompt1,
|
42 |
prompt2=prompt2,
|
43 |
sv=merge_ratio,
|
@@ -48,12 +48,19 @@ def predict(prompt1, prompt2, merge_ratio, guidance, steps, sharpness, seed=1231
|
|
48 |
guidance_scale=guidance,
|
49 |
lcm_origin_steps=50,
|
50 |
output_type="pil",
|
51 |
-
return_dict=False,
|
52 |
)
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
|
56 |
-
css="""
|
57 |
#container{
|
58 |
margin: 0 auto;
|
59 |
max-width: 80rem;
|
@@ -74,7 +81,8 @@ with gr.Blocks(css=css) as demo:
|
|
74 |
RTSD leverages the expertise provided by Latent Consistency Models (LCM). For more information about LCM,
|
75 |
visit their website at [Latent Consistency Models](https://latent-consistency-models.github.io/).
|
76 |
|
77 |
-
""",
|
|
|
78 |
)
|
79 |
with gr.Row():
|
80 |
with gr.Column():
|
@@ -90,7 +98,9 @@ with gr.Blocks(css=css) as demo:
|
|
90 |
sharpness = gr.Slider(
|
91 |
value=1.0, minimum=0, maximum=1, step=0.001, label="Sharpness"
|
92 |
)
|
93 |
-
seed = gr.Slider(
|
|
|
|
|
94 |
prompt1 = gr.Textbox(label="Prompt 1")
|
95 |
prompt2 = gr.Textbox(label="Prompt 2")
|
96 |
generate_bt = gr.Button("Generate")
|
|
|
37 |
|
38 |
def predict(prompt1, prompt2, merge_ratio, guidance, steps, sharpness, seed=1231231):
|
39 |
torch.manual_seed(seed)
|
40 |
+
results = pipe(
|
41 |
prompt1=prompt1,
|
42 |
prompt2=prompt2,
|
43 |
sv=merge_ratio,
|
|
|
48 |
guidance_scale=guidance,
|
49 |
lcm_origin_steps=50,
|
50 |
output_type="pil",
|
51 |
+
# return_dict=False,
|
52 |
)
|
53 |
+
nsfw_content_detected = (
|
54 |
+
results.nsfw_content_detected[0]
|
55 |
+
if "nsfw_content_detected" in results
|
56 |
+
else False
|
57 |
+
)
|
58 |
+
if nsfw_content_detected:
|
59 |
+
raise gr.Error("NSFW content detected. Please try another prompt.")
|
60 |
+
return results.images[0]
|
61 |
|
62 |
|
63 |
+
css = """
|
64 |
#container{
|
65 |
margin: 0 auto;
|
66 |
max-width: 80rem;
|
|
|
81 |
RTSD leverages the expertise provided by Latent Consistency Models (LCM). For more information about LCM,
|
82 |
visit their website at [Latent Consistency Models](https://latent-consistency-models.github.io/).
|
83 |
|
84 |
+
""",
|
85 |
+
elem_id="intro",
|
86 |
)
|
87 |
with gr.Row():
|
88 |
with gr.Column():
|
|
|
98 |
sharpness = gr.Slider(
|
99 |
value=1.0, minimum=0, maximum=1, step=0.001, label="Sharpness"
|
100 |
)
|
101 |
+
seed = gr.Slider(
|
102 |
+
randomize=True, minimum=0, maximum=12013012031030, label="Seed"
|
103 |
+
)
|
104 |
prompt1 = gr.Textbox(label="Prompt 1")
|
105 |
prompt2 = gr.Textbox(label="Prompt 2")
|
106 |
generate_bt = gr.Button("Generate")
|
lcm_txt2img/pipeline.py
CHANGED
@@ -308,7 +308,7 @@ class LatentConsistencyModelPipeline(DiffusionPipeline):
|
|
308 |
#denoised = denoised.to(prompt_embeds.dtype)
|
309 |
if not output_type == "latent":
|
310 |
image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
|
311 |
-
|
312 |
has_nsfw_concept = None
|
313 |
else:
|
314 |
image = denoised
|
|
|
308 |
#denoised = denoised.to(prompt_embeds.dtype)
|
309 |
if not output_type == "latent":
|
310 |
image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
|
311 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
312 |
has_nsfw_concept = None
|
313 |
else:
|
314 |
image = denoised
|