# Spaces: Running on Zero
import os | |
import yaml | |
import subprocess | |
import sys | |
import spaces | |
import numpy as np | |
from nsfw_detector import NSFWDetector, create_error_image | |
from PIL import Image | |
import time | |
# import logging | |
from threading import Timer | |
# logging.basicConfig(level=logging.INFO) | |
# logger = logging.getLogger(__name__) | |
# Module-level state shared by the model loader and the idle-unload timer.
TIMEOUT_SECONDS = 2 * 60  # idle period, in seconds, before the model is dropped
# Nothing is loaded or scheduled until the first generation request.
global_model = last_use_time = unload_timer = None
# Fetch the Sana sources on first start; later starts reuse the checkout.
if not os.path.exists('Sana'):
    # check=True: stop on the git failure itself instead of hitting a
    # confusing FileNotFoundError from the chdir below.
    subprocess.run(
        ['git', 'clone', 'https://github.com/NVlabs/Sana.git'],
        check=True,
    )

# The relative paths used further down (config, builder, `pip install -e .`)
# are resolved from inside the checkout.
os.chdir('Sana')
# Workarounds | |
def modify_builder():
    """Register the unsloth Gemma text encoder in Sana's model builder.

    Rewrites ``diffusion/model/builder.py`` (relative to the current working
    directory) so that the ``text_encoder_dict`` literal contains an
    ``"unsloth-gemma-2-2b-it"`` entry.

    The file is left untouched when the entry is already present or when no
    ``text_encoder_dict = {`` line is found.
    """
    builder_path = 'diffusion/model/builder.py'
    new_entry = ' "unsloth-gemma-2-2b-it": "unsloth/gemma-2-2b-it",\n'

    with open(builder_path, 'r') as f:
        content = f.readlines()

    # Idempotence: a reused checkout has already been patched; inserting
    # again would only pile up duplicate keys.
    if any('"unsloth-gemma-2-2b-it"' in line for line in content):
        return

    for i, line in enumerate(content):
        if 'text_encoder_dict = {' in line:
            # Insert right after the opening brace. A fixed offset into the
            # dict breaks (lands outside the literal) as soon as upstream
            # changes the number of entries.
            content.insert(i + 1, new_entry)
            break
    else:
        # No dict definition found: nothing to patch, nothing to rewrite.
        return

    with open(builder_path, 'w') as f:
        f.writelines(content)
def modify_config():
    """Point the Sana 1600M / 1024px config at the unsloth Gemma encoder.

    Loads the YAML config, sets ``text_encoder.text_encoder_name`` to
    ``'unsloth-gemma-2-2b-it'`` and ``model.mixed_precision`` to ``'bf16'``,
    then writes the whole document back in block style.
    """
    config_path = 'configs/sana_config/1024ms/Sana_1600M_img1024.yaml'

    with open(config_path, 'r') as src:
        settings = yaml.safe_load(src)

    # (section, key, value) overrides, applied in order.
    overrides = (
        ('text_encoder', 'text_encoder_name', 'unsloth-gemma-2-2b-it'),
        ('model', 'mixed_precision', 'bf16'),
    )
    for section, key, value in overrides:
        settings[section][key] = value

    with open(config_path, 'w') as dst:
        yaml.dump(settings, dst, default_flow_style=False)
# Install the runtime dependencies at start-up, one pip invocation per entry.
setup_commands = [
    "pip install torch",  # init raw torch
    "pip install -U pip",  # update pip
    "pip install -U xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu121",  # fast attn
    "pip install pyyaml",
    "pip install -e .",  # install sana
]
for cmd in setup_commands:
    print(f"Running: {cmd}")
    completed = subprocess.run(cmd.split())
    # Still best effort (later commands run regardless), but a failed install
    # is reported instead of resurfacing later as an unexplained ImportError.
    if completed.returncode != 0:
        print(
            f"Warning: '{cmd}' exited with code {completed.returncode}",
            file=sys.stderr,
        )
import torch | |
import gradio as gr | |
# Expose the current working directory to the import system.
sys.path.append('.')

# Patch the checkout on disk (config first, then builder) before
# SanaPipeline is imported.
for apply_patch in (modify_config, modify_builder):
    apply_patch()
from Sana.app.sana_pipeline import SanaPipeline | |
def unload_model():
    """Drop the cached pipeline once it has been idle long enough.

    Used as the callback of the idle timer. Does nothing (and returns
    ``None``) when the model was never used or when the last use is more
    recent than ``TIMEOUT_SECONDS``.

    Returns:
        A status string when the model was unloaded, otherwise ``None``.
    """
    global global_model, last_use_time

    now = time.time()
    # Guard: no recorded use yet, or still inside the idle window.
    if not last_use_time or (now - last_use_time) < TIMEOUT_SECONDS:
        return None

    # logger.info("Unloading model due to inactivity...")
    global_model = None
    torch.cuda.empty_cache()
    return "Model unloaded due to inactivity"
def reset_timer():
    """Record the current use and (re)arm the idle-unload timer.

    Cancels any pending timer, stamps ``last_use_time`` with the current
    wall-clock time, and schedules ``unload_model`` to run after
    ``TIMEOUT_SECONDS``.
    """
    global unload_timer, last_use_time

    if unload_timer is not None:
        unload_timer.cancel()

    last_use_time = time.time()
    unload_timer = Timer(TIMEOUT_SECONDS, unload_model)
    # Daemon thread: a pending timer must not keep the interpreter alive for
    # up to TIMEOUT_SECONDS after the app itself has shut down.
    unload_timer.daemon = True
    unload_timer.start()
def _tensor_to_pil(tensor):
    """Convert a single image tensor into a PIL image.

    Rescales with ``(x + 1) / 2``, multiplies by 255, clamps to ``[0, 255]``,
    casts to ``uint8`` and moves axis 0 to the end before handing the array
    to PIL.
    """
    # NOTE(review): assumes a channel-first (C, H, W) tensor with values in
    # [-1, 1] -- confirm against the pipeline's output.
    scaled = ((tensor + 1) / 2).float().cpu()
    pixels = (scaled * 255).clamp(0, 255).numpy().astype(np.uint8)
    return Image.fromarray(pixels.transpose(1, 2, 0))


def generate_image(prompt, height, width, guidance_scale, pag_guidance_scale, num_inference_steps):
    """Generate one image with the Sana pipeline and screen it for NSFW content.

    The pipeline is loaded on first use and cached in ``global_model``; every
    call re-arms the idle-unload timer.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        height: Requested image height, forwarded to the pipeline.
        width: Requested image width, forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        pag_guidance_scale: Forwarded to the pipeline as ``pag_guidance_scale``.
        num_inference_steps: Forwarded to the pipeline as ``num_inference_steps``.

    Returns:
        The generated PIL image when the detector reports category ``"SAFE"``;
        otherwise the result of ``create_error_image()``.

    Raises:
        gr.Error: If anything fails while loading, generating or checking.
    """
    global global_model
    try:
        cuda_available = torch.cuda.is_available()
        device = torch.device("cuda:0" if cuda_available else "cpu")
        if cuda_available:
            torch.cuda.empty_cache()

        # Work on a local reference so the idle timer setting ``global_model``
        # to None on its own thread cannot pull the model out from under us.
        model = global_model
        if model is None:
            # logger.info("Loading model...")
            model = SanaPipeline("configs/sana_config/1024ms/Sana_1600M_img1024.yaml")
            model.from_pretrained("hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth")
            # Publish only after the weights loaded: a failed load must not
            # leave a half-initialised pipeline cached for later requests.
            global_model = model

        # Re-arm on every request so a model in active use is not treated as idle.
        reset_timer()

        # Seed from the wall clock (one-second resolution).
        generator = torch.Generator(device=device).manual_seed(int(time.time()))
        batch = model(
            prompt=prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            pag_guidance_scale=pag_guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
        )

        # Only the first image of the returned batch is used.
        image = _tensor_to_pil(batch[0])

        # Check for NSFW content; only the reported category decides.
        detector = NSFWDetector()
        _, category, _ = detector.check_image(image)
        if category != "SAFE":
            # logger.warning(f"NSFW content detected ({category} with {confidence:.2f}% confidence)")
            return create_error_image()
        return image
    except Exception as e:
        # logger.error(f"Error in generate_image: {str(e)}")
        # Chain the cause so the original traceback is kept for debugging.
        raise gr.Error(f"Generation failed: {str(e)}") from e
# Gradio interface: a banner, a controls column next to an output column,
# example prompts, and a footer link.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from statement order only -- confirm against the deployed layout.
with gr.Blocks(theme=gr.themes.Default(), css=""".center-text {text-align: center;}
.footer-link {text-align: center; margin: 20px 0;}
.slider-pad {margin-bottom: 24px;}""") as interface:
    # Title banner.
    with gr.Row(elem_id="banner"):
        with gr.Column():
            gr.Markdown("# Sana 1.6B", elem_classes="center-text")
            gr.Markdown("Generate high-resolution images up to 4096x4096 using the Sana 1.6B model, fast.", elem_classes="center-text")
    with gr.Row():
        # Input controls.
        with gr.Column(scale=2):
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=3)
            with gr.Row():
                with gr.Column():
                    height = gr.Slider(minimum=512, maximum=4096, step=64, value=1024, label="Height")
                    width = gr.Slider(minimum=512, maximum=4096, step=64, value=1024, label="Width")
                with gr.Column():
                    guidance_scale = gr.Slider(minimum=1.0, maximum=10.0, step=0.5, value=5.0, label="Guidance Scale")
                    pag_guidance_scale = gr.Slider(minimum=1.0, maximum=5.0, step=0.1, value=2.0, label="PAG Guidance Scale")
                    num_inference_steps = gr.Slider(minimum=2, maximum=50, step=1, value=18, label="Number of Steps")
            gr.Markdown("*Note: Higher guidance scales provide stronger adherence to the prompt. PAG guidance helps with image-text alignment.*")
            gr.Markdown("⏱️ Be patient, the model loads into memory slow first time around.")
            generate_btn = gr.Button("Generate", variant="primary")
        # Generated result.
        with gr.Column(scale=2):
            output = gr.Image(label="Generated Image", height=512)
    # Example prompts; each row supplies values in the same order as `inputs`.
    gr.Examples(
        examples=[
            ["a cyberpunk cat with a neon sign that says 'Sana'", 1024, 1024, 5.0, 2.0, 18],
            ["a beautiful sunset over a mountain landscape", 1024, 1024, 5.0, 2.0, 18],
            ["a futuristic city with flying cars", 1024, 1024, 5.0, 2.0, 18]
        ],
        inputs=[prompt, height, width, guidance_scale, pag_guidance_scale, num_inference_steps],
        outputs=output,
        fn=generate_image,
    )
    # The button runs generate_image with the current control values and
    # shows the result in the output image.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, height, width, guidance_scale, pag_guidance_scale, num_inference_steps],
        outputs=output
    )
    gr.Markdown("[link to model](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px)", elem_classes="center-text footer-link")
# Start the app; this runs at module level, so importing the file launches it.
interface.launch()