awacke1's picture
Update app.py
02659a8
raw
history blame
9.62 kB
import base64
import datetime
import os
import random
import re
import shutil  # Added for zip functionality
import time
import zipfile
from io import BytesIO

import gradio as gr
import numpy as np
import psutil
import pytz
import torch
from diffusers import DiffusionPipeline, LCMScheduler, AutoencoderTiny
from PIL import Image
# Optional Intel Extension for PyTorch; the app works without it.
try:
    import intel_extension_for_pytorch as ipex
except BaseException:
    # Deliberately as broad as a bare ``except``: any failure while importing
    # the optional extension is ignored and startup continues.
    # NOTE(review): narrowing this to ImportError may be possible — first
    # confirm the extension never aborts (e.g. via SystemExit) on a torch
    # version mismatch.
    pass
# Runtime switches read from the environment (each is None when unset).
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Apple-silicon MPS availability is only recorded here; the device override
# for MPS is applied later in the file, after the pipeline is created.
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
# Preference order: CUDA, then Intel XPU, then CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
)
torch_device = device
# Half precision only on accelerators: float16 inference on the CPU is either
# unsupported by some kernels or very slow, so the CPU fallback uses float32.
torch_dtype = torch.float16 if device.type in ("cuda", "xpu") else torch.float32
# add file save and download and clear:
# Function to create a zip file from a list of files
def create_zip(files):
    """Bundle ``files`` into a new timestamped zip archive in the working directory.

    Each file is stored in the archive under its base name only (directory
    components are dropped).

    Args:
        files: iterable of paths of the files to archive.

    Returns:
        The archive name, ``images_<YYYYmmddHHMMSS>.zip``, relative to the
        current working directory. A second call within the same second
        overwrites the previous archive.
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    zip_filename = f"images_{timestamp}.zip"
    # Relies on the module-level ``import zipfile``. PNG/JPEG data is already
    # compressed, so the default ZIP_STORED mode is kept.
    with zipfile.ZipFile(zip_filename, "w") as zipf:
        for file in files:
            zipf.write(file, os.path.basename(file))
    return zip_filename
# Function to encode a file to base64
def encode_file_to_base64(file_path):
    """Read ``file_path`` as bytes and return its base64 encoding as ``str``."""
    with open(file_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode()
# Function to save all images as a zip file and provide a base64 download link
def save_all_images(images):
    """Zip ``images`` and build an HTML link that downloads the archive inline.

    Args:
        images: sized collection of image file paths.

    Returns:
        ``(zip_filename, html_link)``, where the link embeds the whole archive
        as a base64 ``data:`` URL, or ``(None, None)`` when ``images`` is empty.
    """
    if not len(images):
        return None, None
    archive = create_zip(images)
    payload = encode_file_to_base64(archive)
    link = (
        f'<a href="data:application/zip;base64,{payload}" '
        f'download="{archive}">Download All</a>'
    )
    return archive, link
# Function to clear all image files
def clear_all_images():
    """Delete every ``.png``/``.jpg``/``.jpeg`` file (case-insensitive) in the cwd.

    Directories whose names end in an image extension are skipped, and a file
    that disappears between listing and removal (e.g. a concurrent clear) is
    ignored, so one odd entry no longer aborts the whole clear.
    """
    base_dir = os.getcwd()
    for name in os.listdir(base_dir):
        if not name.lower().endswith((".png", ".jpg", ".jpeg")):
            continue
        path = os.path.join(base_dir, name)
        # os.remove cannot delete directories; leave them alone.
        if os.path.isdir(path):
            continue
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already gone — nothing left to clear for this entry.
            continue
# "Save All" / "Clear All" toolbar buttons. They are created unrendered here;
# call ``attach_file_buttons()`` inside a ``gr.Blocks`` context to place them
# in the layout and wire their click handlers.
save_all_button = gr.Button("💾 Save All", scale=1)
clear_all_button = gr.Button("🗑️ Clear All", scale=1)


def save_all_button_click():
    """Zip every image in the working directory and return an HTML download link.

    Returns:
        The ``<a>`` tag built by ``save_all_images`` (archive embedded as a
        base64 data URL), or an empty string when there are no images, so an
        output component such as ``gr.HTML`` can display it.
    """
    images = [
        file
        for file in os.listdir()
        if file.lower().endswith((".png", ".jpg", ".jpeg"))
    ]
    _zip_filename, download_link = save_all_images(images)
    return download_link or ""


def clear_all_button_click():
    """Delete all image files in the working directory."""
    clear_all_images()


def attach_file_buttons(download_output=None):
    """Render the Save All / Clear All buttons and wire their click handlers.

    Must be called inside an active ``gr.Blocks`` context: Gradio only accepts
    event listeners (``.click``) there, so they cannot be attached at module
    level.

    Args:
        download_output: optional component (e.g. ``gr.HTML``) that receives
            the download link returned by ``save_all_button_click``.
    """
    save_all_button.render()
    clear_all_button.render()
    save_all_button.click(save_all_button_click, outputs=download_output)
    clear_all_button.click(clear_all_button_click)
if mps_available:
    # Apple silicon: cast the weights on the CPU in float32, then move the
    # pipeline to the MPS device.
    device = torch.device("mps")
    torch_device = "cpu"
    torch_dtype = torch.float32

# Log the effective configuration. Printed after the MPS override so the
# reported device is the one the pipeline actually runs on.
print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
print(f"TORCH_COMPILE: {TORCH_COMPILE}")
print(f"device: {device}")

# The safety checker is kept only when SAFETY_CHECKER is exactly "True".
if SAFETY_CHECKER == "True":
    pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7")
else:
    pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7", safety_checker=None)
# Replace the pipeline's default scheduler with LCMScheduler, reusing its config.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
# Cast to torch_dtype on torch_device, then move to device; the two devices
# only differ on MPS, where the cast happens on the CPU first.
pipe.to(device=torch_device, dtype=torch_dtype).to(device)
pipe.unet.to(memory_format=torch.channels_last)
pipe.set_progress_bar_config(disable=True)
# Enable attention slicing when the machine has less than 64 GiB of total RAM
# (as reported by psutil).
if psutil.virtual_memory().total < 64 * 1024**3:
    pipe.enable_attention_slicing()
# NOTE(review): any non-empty TORCH_COMPILE value (even "False" or "0") turns
# compilation on, unlike SAFETY_CHECKER which must equal "True" — confirm intent.
if TORCH_COMPILE:
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
    # One-step dummy generation, presumably to trigger compilation at startup
    # rather than on the first user request.
    # NOTE(review): this warmup runs before the LCM LoRA below is loaded and
    # fused — confirm the compiled modules pick up the fused weights.
    pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
# Load the LCM LoRA for SD v1.5 and fuse it into the base model weights.
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
pipe.fuse_lora()
def safe_filename(text):
    """Build a ``.png`` filename from ``text`` and today's date.

    Each run of non-word characters in ``text`` collapses to a single
    underscore, then ``_<YYYYmmdd>.png`` is appended.
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d")
    cleaned = re.sub(r'\W+', '_', text)
    return "{}_{}.png".format(cleaned, stamp)
def encode_image(image):
    """Serialize ``image`` as PNG and return the bytes base64-encoded as ``str``.

    Args:
        image: an object with a PIL-style ``save(fp, format=...)`` method,
            typically a ``PIL.Image.Image``.

    Returns:
        ASCII base64 text of the PNG data, suitable for a ``data:image/png`` URL.
    """
    buffered = BytesIO()
    # Without this save the buffer stays empty and the result is always "".
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()
def fake_gan():
    """List the saved images in the working directory for the gallery, shuffled.

    Returns:
        A randomly ordered list of ``(filename, caption)`` tuples — one per
        regular ``.png``/``.jpg``/``.jpeg`` file (case-insensitive) in the
        current working directory — where ``caption`` is the filename without
        its extension. Empty when there are no images.
    """
    base_dir = os.getcwd()
    img_files = [
        file
        for file in os.listdir(base_dir)
        if file.lower().endswith((".png", ".jpg", ".jpeg"))
        and os.path.isfile(os.path.join(base_dir, file))
    ]
    # Shuffle (instead of sampling with replacement) so every image appears
    # exactly once and keeps its own caption.
    random.shuffle(img_files)
    return [(file, os.path.splitext(file)[0]) for file in img_files]
def _save_generated_image(pil_image, prompt):
    """Save ``pil_image`` in the cwd as ``<YYYYmmdd>_<prompt>.png``; return the path.

    Spaces and newlines in ``prompt`` become ``_``, any other non-alphanumeric
    character is dropped, and the prompt part is truncated to 90 characters.
    """
    safe_date_time = datetime.datetime.now().strftime("%Y%m%d")
    replaced_prompt = prompt.replace(" ", "_").replace("\n", "_")
    safe_prompt = "".join(x for x in replaced_prompt if x.isalnum() or x == "_")[:90]
    image_path = f"{safe_date_time}_{safe_prompt}.png"
    pil_image.save(image_path)
    print(f"#Image saved as {image_path}")
    return image_path


def predict(prompt, guidance, steps, seed=1231231):
    """Generate a 512x512 image for ``prompt`` with the global LCM pipeline.

    The result is also saved (best effort) into the working directory, where
    the gallery and "Save All" features look for images.

    Args:
        prompt: text prompt.
        guidance: guidance scale passed to the pipeline.
        steps: number of inference steps.
        seed: seed passed to ``torch.manual_seed`` (which also reseeds torch's
            global default generator).

    Returns:
        The first generated PIL image, or ``None`` if the pipeline returned none.
    """
    generator = torch.manual_seed(seed)
    last_time = time.time()
    results = pipe(
        prompt=prompt,
        generator=generator,
        num_inference_steps=steps,
        guidance_scale=guidance,
        width=512,
        height=512,
        output_type="pil",
    )
    print(f"Pipe took {time.time() - last_time} seconds")
    # The NSFW flag may be absent (presumably when the pipeline was built with
    # safety_checker=None), hence the membership test.
    if "nsfw_content_detected" in results and results.nsfw_content_detected[0]:
        print("NSFW content detected for this prompt")
    if len(results.images) == 0:
        return None
    generated = results.images[0]
    try:
        _save_generated_image(generated, prompt)
    except Exception as err:  # saving is best effort; still return the image
        print(f"Could not save image: {err}")
    return generated
# Stylesheet for the Blocks app: centres the main column (elem_id "container")
# at a maximum width of 40rem and centres the intro heading (elem_id "intro").
css = """
#container{
margin: 0 auto;
max-width: 40rem;
}
#intro{
max-width: 100%;
text-align: center;
margin: 0 auto;
}
"""
# Main UI: prompt box + Generate button, the latest result, a gallery of saved
# images, advanced sampling options and a diffusers usage snippet. Changes to
# the prompt, guidance, steps or seed all re-run ``predict``.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown(
            """##🕹️ Real Time 🎨 ImageGen Gallery 🌐""",
            elem_id="intro",
        )
        with gr.Row():
            with gr.Row():
                prompt = gr.Textbox(
                    placeholder="Insert your prompt here:", scale=5, container=False
                )
                generate_bt = gr.Button("Generate", scale=1)
        # Image Result from last prompt
        image = gr.Image(type="filepath")
        # Gallery of Generated Images with Image Names in Random Set to Download
        with gr.Row(variant="compact"):
            # NOTE(review): ``text`` is not connected to any event below.
            text = gr.Textbox(
                label="Image Sets",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
            )
            btn = gr.Button("Generate Gallery of Saved Images")
        gallery = gr.Gallery(
            label="Generated Images", show_label=False, elem_id="gallery"
        )
        # Advanced Generate Options
        with gr.Accordion("Advanced options", open=False):
            guidance = gr.Slider(
                label="Guidance", minimum=0.0, maximum=5, value=0.3, step=0.001
            )
            steps = gr.Slider(label="Steps", value=4, minimum=2, maximum=10, step=1)
            seed = gr.Slider(
                randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1
            )
        # Diffusers
        with gr.Accordion("Run with diffusers"):
            gr.Markdown(
                """## Running LCM-LoRAs it with `diffusers`
```bash
pip install diffusers==0.23.0
```
```py
from diffusers import DiffusionPipeline, LCMScheduler
pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7").to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5") #yes, it's a normal LoRA
results = pipe(
prompt="ImageEditor",
num_inference_steps=4,
guidance_scale=0.0,
)
results.images[0]
```
"""
            )
        # Function IO Eventing and Controls
        inputs = [prompt, guidance, steps, seed]
        generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        btn.click(fake_gan, None, gallery)
        # NOTE(review): ``prompt.input`` fires on every edit of the textbox, and
        # each ``predict`` call also writes a PNG to disk — confirm this is intended.
        prompt.input(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        guidance.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        steps.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
# Enable the request queue (kept at import time so any launcher inherits it),
# but only start the server when the file is run as a script, so importing
# this module does not block.
demo.queue()
if __name__ == "__main__":
    demo.launch()