Juno360219 commited on
Commit
2d33e3a
·
verified ·
1 Parent(s): f2aa102

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -101
app.py DELETED
@@ -1,101 +0,0 @@
1
- import gradio as gr
2
- import torch
3
- from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
4
- from huggingface_hub import hf_hub_download
5
- from safetensors.torch import load_file
6
- import spaces
7
- import os
8
- from PIL import Image
9
-
10
- SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
11
-
12
- # Constants
13
- base = "stabilityai/stable-diffusion-xl-base-1.0"
14
- repo = "ByteDance/SDXL-Lightning"
15
- checkpoints = {
16
- "1-Step" : ["sdxl_lightning_1step_unet_x0.safetensors", 1],
17
- "2-Step" : ["sdxl_lightning_2step_unet.safetensors", 2],
18
- "4-Step" : ["sdxl_lightning_4step_unet.safetensors", 4],
19
- "8-Step" : ["sdxl_lightning_8step_unet.safetensors", 8],
20
- }
21
-
22
-
23
- # Ensure model and scheduler are initialized in GPU-enabled function
24
- if torch.cuda.is_available():
25
- pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
26
-
27
- if SAFETY_CHECKER:
28
- from safety_checker import StableDiffusionSafetyChecker
29
- from transformers import CLIPFeatureExtractor
30
-
31
- safety_checker = StableDiffusionSafetyChecker.from_pretrained(
32
- "CompVis/stable-diffusion-safety-checker"
33
- ).to("cuda")
34
- feature_extractor = CLIPFeatureExtractor.from_pretrained(
35
- "openai/clip-vit-base-patch32"
36
- )
37
-
38
- def check_nsfw_images(
39
- images: list[Image.Image],
40
- ) -> tuple[list[Image.Image], list[bool]]:
41
- safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
42
- has_nsfw_concepts = safety_checker(
43
- images=[images],
44
- clip_input=safety_checker_input.pixel_values.to("cuda")
45
- )
46
-
47
- return images, has_nsfw_concepts
48
-
49
- # Function
50
- @spaces.GPU(enable_queue=True)
51
- def generate_image(prompt, ckpt):
52
-
53
- checkpoint = checkpoints[ckpt][0]
54
- num_inference_steps = checkpoints[ckpt][1]
55
-
56
- if num_inference_steps==1:
57
- # Ensure sampler uses "trailing" timesteps and "sample" prediction type for 1-step inference.
58
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
59
- else:
60
- # Ensure sampler uses "trailing" timesteps.
61
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
62
-
63
- pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
64
- results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
65
-
66
- if SAFETY_CHECKER:
67
- images, has_nsfw_concepts = check_nsfw_images(results.images)
68
- if any(has_nsfw_concepts):
69
- gr.Warning("NSFW content detected.")
70
- return Image.new("RGB", (512, 512))
71
- return images[0]
72
- return results.images[0]
73
-
74
-
75
-
76
- # Gradio Interface
77
- description = """
78
- This demo utilizes the SDXL-Lightning model by ByteDance, which is a lightning-fast text-to-image generative model capable of producing high-quality images in 4 steps.
79
- As a community effort, this demo was put together by AngryPenguin. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
80
- """
81
-
82
- with gr.Blocks(css="style.css") as demo:
83
- gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
84
- gr.Markdown(description)
85
- with gr.Group():
86
- with gr.Row():
87
- prompt = gr.Textbox(label='Enter you image prompt:', scale=8)
88
- ckpt = gr.Dropdown(label='Select inference steps',choices=['1-Step', '2-Step', '4-Step', '8-Step'], value='4-Step', interactive=True)
89
- submit = gr.Button(scale=1, variant='primary')
90
- img = gr.Image(label='SDXL-Lightning Generated Image')
91
-
92
- prompt.submit(fn=generate_image,
93
- inputs=[prompt, ckpt],
94
- outputs=img,
95
- )
96
- submit.click(fn=generate_image,
97
- inputs=[prompt, ckpt],
98
- outputs=img,
99
- )
100
-
101
- demo.queue().launch()