Simple Diffusion XS

XS Size, Excess Quality

Status: training, epoch 8

Example

import torch
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import AutoModel, AutoTokenizer
from PIL import Image
from tqdm.auto import tqdm
import os

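def encode_prompt(prompt, negative_prompt, device, dtype):
    # Uses the global `tokenizer` and `text_model` loaded in __main__.
    # Returns [negative; positive] embeddings, the layout generate_latents expects.
    if negative_prompt is None:
        negative_prompt = ""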

    with torch.no_grad():
        positive_inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=512,
            truncation=True,
        ).to(device)
        positive_embeddings = text_model.encode_texts(
            positive_inputs.input_ids, positive_inputs.attention_mask
        )
        if positive_embeddings.ndim == 2:
            positive_embeddings = positive_embeddings.unsqueeze(1)
        positive_embeddings = positive_embeddings.to(device, dtype=dtype)
        
        negative_inputs = tokenizer(
            negative_prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=512,
            truncation=True,
        ).to(device)
        negative_embeddings = text_model.encode_texts(negative_inputs.input_ids, negative_inputs.attention_mask)
        if negative_embeddings.ndim == 2:
            negative_embeddings = negative_embeddings.unsqueeze(1)
        negative_embeddings = negative_embeddings.to(device, dtype=dtype)
    return torch.cat([negative_embeddings, positive_embeddings], dim=0)

def generate_latents(embeddings, height=576, width=576, num_inference_steps=50, guidance_scale=5.5):
    # Uses the global `unet` and `scheduler` loaded in __main__.
    with torch.no_grad():
        device, dtype = embeddings.device, embeddings.dtype
        # `embeddings` is [negative; positive], so the image batch is half its size
        half = embeddings.shape[0] // 2
        # 16-channel latents at 1/8 of the image resolution
        latent_shape = (half, 16, height // 8, width // 8)
        latents = torch.randn(latent_shape, device=device, dtype=dtype)
        latents = latents * scheduler.init_noise_sigma

        scheduler.set_timesteps(num_inference_steps)

        for t in tqdm(scheduler.timesteps, desc="Generating"):
            # Denoise the unconditional and conditional branches in a single batch
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)
            noise_pred = unet(latent_model_input, t, embeddings).sample
            # Classifier-free guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents


def decode_latents(latents, vae, output_type="pil"):
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
    with torch.no_grad():
        images = vae.decode(latents).sample
    images = (images / 2 + 0.5).clamp(0, 1)
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    if output_type == "pil":
        images = (images * 255).round().astype("uint8")
        images = [Image.fromarray(image) for image in images]
    return images

# Example usage:
if __name__ == "__main__":
    device = "cuda"
    dtype = torch.float16

    prompt = "кот"  # Russian for "cat"; the text encoder is multilingual
    negative_prompt = "bad quality"
    tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip")
    text_model = AutoModel.from_pretrained(
        "visheratin/mexma-siglip", torch_dtype=dtype, trust_remote_code=True
    ).to(device, dtype=dtype).eval()
    
    embeddings = encode_prompt(prompt, negative_prompt, device, dtype)    

    pipeid = "AiArtLab/sdxs"
    variant = "fp16"
    
    unet = UNet2DConditionModel.from_pretrained(pipeid, subfolder="unet", variant=variant).to(device, dtype=dtype).eval()
    vae = AutoencoderKL.from_pretrained(pipeid, subfolder="vae", variant=variant).to(device, dtype=dtype).eval()
    scheduler = DDPMScheduler.from_pretrained(pipeid, subfolder="scheduler")


    height, width = 576, 576
    num_inference_steps = 40
    output_folder, project_name = "samples", "sdxs"
    latents = generate_latents(
        embeddings=embeddings,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
    )

    images = decode_latents(latents, vae)

    os.makedirs(output_folder, exist_ok=True)
    for idx, image in enumerate(images):
        image.save(f"{output_folder}/{project_name}_{idx}.jpg")

    print("Images generated and saved to:", output_folder)

Introduction

Fast, Lightweight & Multilingual Diffusion for Everyone

We are AiArtLab, a small team of enthusiasts with a limited budget. Our goal is to create a compact and fast model that can be trained on consumer graphics cards (a full training cycle, not LoRA). We chose U-Net for its ability to handle small datasets efficiently and to train quickly even on a 16GB GPU (e.g., an RTX 4080). Our budget was a few thousand dollars, far less than the tens of millions spent on competitors such as SDXL, so we set out to build a small but efficient model: something like SD1.5, but for 2025.

Encoder Architecture (Text and Images)

We experimented with various encoders and concluded that large models like LLaMA or T5 XXL are unnecessary for high-quality generation. However, we needed an encoder that understands the context of the query, focusing on "prompt understanding" rather than "prompt following." We chose the multilingual encoder Mexma-SigLIP, which supports 80 languages and processes whole sentences rather than individual tokens.

Mexma accepts up to 512 tokens, producing a large 512x1152 matrix that slows down training. We therefore used a pooling layer to reduce this matrix to a single 1x1152 vector, and passed the result through a linear text projector to make it compatible with SigLIP image embeddings. This aligns text embeddings with image embeddings, potentially paving the way to a unified multimodal model: image embeddings can be mixed with textual descriptions in queries, and the model can even be trained on images alone, without text descriptions. This should simplify training on video, where annotation is challenging, and yield more consistent, seamless video generation by feeding in the embeddings of previous frames with decay. In the future, we aim to extend the model to 3D and video generation.
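A minimal sketch of the pooling-plus-projection step, assuming masked mean pooling and a single `nn.Linear` projector. The card names neither the pooling type nor the projector's exact shape, so `PooledTextProjector` and its dimensions are hypothetical. It turns a 512x1152 token matrix into one 1x1152 vector per prompt, i.e. the same `[batch, 1, 1152]` shape the Example code produces with `unsqueeze(1)`.

```python
import torch
import torch.nn as nn


class PooledTextProjector(nn.Module):
    """Hypothetical sketch: pool per-token encoder states into one vector, then project."""

    def __init__(self, hidden_dim: int = 1152, embed_dim: int = 1152):
        super().__init__()
        # Assumption: a single linear layer plays the role of the "text projector".
        self.proj = nn.Linear(hidden_dim, embed_dim)

    def forward(self, token_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # token_states: [batch, 512, 1152]; attention_mask: [batch, 512], 1 = real token
        mask = attention_mask.unsqueeze(-1).to(token_states.dtype)
        # Masked mean pooling (assumed): average only over non-padding tokens.
        pooled = (token_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        # Project and add a length-1 sequence axis: [batch, 1, embed_dim]
        return self.proj(pooled).unsqueeze(1)


if __name__ == "__main__":
    states = torch.randn(2, 512, 1152)
    mask = torch.zeros(2, 512, dtype=torch.long)
    mask[0, :10] = 1   # a short prompt: 10 real tokens
    mask[1, :300] = 1  # a long prompt: 300 real tokens
    with torch.no_grad():
        out = PooledTextProjector()(states, mask)
    print(out.shape)  # torch.Size([2, 1, 1152])
```

Averaging only over real tokens keeps the padding of a 512-token window from diluting short prompts; a learned attention pooling would be an equally plausible choice.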

U-Net Architecture

We chose a smooth channel pyramid, [384, 576, 768, 960], with two layers per block, [4, 6, 8, 10] transformer layers per level, and 1152/48 = 24 attention heads. This architecture gave the highest training speed at a model size of around 2 billion parameters (and fits perfectly on my RTX 4080). We believe that, thanks to its greater 'depth', its quality will be on par with SDXL despite its smaller 'size'. The model can be expanded to 4 billion parameters by adding an 1152-channel level, which would achieve perfect symmetry with the embedding size (something we value for its elegance) and probably 'Flux/MJ-level' quality.
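The pyramid's arithmetic can be checked in a few lines of Python. This sketch assumes a fixed per-head width of 48 channels at every level, which is one reading of "1152/48 = 24"; the constant and function names are illustrative, not the project's actual config keys.

```python
# Values from the card; HEAD_DIM = 48 is an assumed fixed per-head width.
BLOCK_OUT_CHANNELS = [384, 576, 768, 960]
TRANSFORMER_DEPTHS = [4, 6, 8, 10]
HEAD_DIM = 48
TEXT_EMBED_DIM = 1152  # Mexma-SigLIP embedding width


def heads_per_level(channels, head_dim=HEAD_DIM):
    """Attention heads per level if every head is `head_dim` channels wide."""
    for c in channels:
        if c % head_dim:
            raise ValueError(f"{c} channels are not divisible by head_dim={head_dim}")
    return [c // head_dim for c in channels]


def channel_steps(channels):
    """Width increase between consecutive levels."""
    return [b - a for a, b in zip(channels, channels[1:])]


def describe(channels, depths, head_dim=HEAD_DIM):
    """(channels, transformer depth, heads) for each level of the pyramid."""
    return list(zip(channels, depths, heads_per_level(channels, head_dim)))


if __name__ == "__main__":
    print(describe(BLOCK_OUT_CHANNELS, TRANSFORMER_DEPTHS))
    # [(384, 4, 8), (576, 6, 12), (768, 8, 16), (960, 10, 20)]
    print(channel_steps(BLOCK_OUT_CHANNELS))  # [192, 192, 192]
    expanded = BLOCK_OUT_CHANNELS + [TEXT_EMBED_DIM]  # the proposed 4B variant
    print(channel_steps(expanded)[-1], heads_per_level(expanded)[-1])  # 192 24
```

Every level is a multiple of 48 and the widths grow in equal steps of 192, so appending an 1152-channel level continues the same progression and lands exactly on the embedding width.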

VAE Architecture

We chose an unconventional 16-channel AuraDiffusion VAE with 8x spatial compression, which preserves details, text, and anatomy without the 'haze' characteristic of SD3/Flux. We used a fast version with FFN convolution and observed minor texture damage on fine patterns, which may lower its benchmark scores. Upscalers such as ESRGAN can fix these artifacts. Overall, we believe this VAE is highly underrated.
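The latent geometry implied by an 8x, 16-channel VAE can be sketched as follows. The `normalize` helper is assumed to be the training-side inverse of the `decode_latents` step in the Example (`latents / scaling_factor + shift_factor`); the factor values used below are placeholders, since the real ones live in the VAE config.

```python
def latent_shape(height, width, channels=16, downsample=8):
    """Shape (C, H, W) of the VAE latent for an image of the given size."""
    if height % downsample or width % downsample:
        raise ValueError(f"image size must be divisible by {downsample}")
    return (channels, height // downsample, width // downsample)


def compression_ratio(height, width, channels=16, downsample=8, image_channels=3):
    """How many pixel values map to one latent value."""
    c, h, w = latent_shape(height, width, channels, downsample)
    return (image_channels * height * width) / (c * h * w)


def normalize(z, scaling_factor, shift_factor):
    # Assumed training-side transform: the inverse of decode_latents in the Example.
    return (z - shift_factor) * scaling_factor


def denormalize(z, scaling_factor, shift_factor):
    # Same formula as decode_latents: latents / scaling_factor + shift_factor.
    return z / scaling_factor + shift_factor


if __name__ == "__main__":
    print(latent_shape(576, 576))  # (16, 72, 72)
    print(compression_ratio(576, 576))  # 12.0
    scale, shift = 0.5, 0.1  # placeholder values, not the real VAE config
    print(denormalize(normalize(1.25, scale, shift), scale, shift))
```

With 16 channels the 8x VAE keeps four times as many latent values as a 4-channel latent of the same spatial size, which is one way to see why it retains more fine detail.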

Training Process

Optimizer

We tested several optimizers (AdamW, Lion, Optimi-AdamW, Adafactor, and AdamW-8bit) and chose AdamW-8bit. Optimi-AdamW showed the smoothest gradient decay curve, while AdamW-8bit behaves more chaotically. However, AdamW-8bit's smaller memory footprint allows larger batch sizes, maximizing training speed on low-cost GPUs (we trained on 4xA6000 and 5xL40s).
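A back-of-the-envelope sketch of why AdamW-8bit frees memory for larger batches. It counts only the two Adam moment buffers per parameter (4 bytes each in fp32 versus 1 byte each in 8-bit) and ignores weights, gradients, activations, and quantization statistics, so the figures are rough estimates, not measurements.

```python
def adam_state_bytes(num_params, bytes_per_value, moments=2):
    """Approximate memory for Adam-style optimizer state (first and second moments)."""
    return num_params * moments * bytes_per_value


GB = 10**9
NUM_PARAMS = 2 * 10**9  # ~2B-parameter U-Net, as stated on the card

if __name__ == "__main__":
    fp32_state = adam_state_bytes(NUM_PARAMS, bytes_per_value=4)
    int8_state = adam_state_bytes(NUM_PARAMS, bytes_per_value=1)
    print(f"AdamW, fp32 states: {fp32_state / GB:.1f} GB")  # 16.0 GB
    print(f"AdamW-8bit states: {int8_state / GB:.1f} GB")  # 4.0 GB
    print(f"difference: {(fp32_state - int8_state) / GB:.1f} GB")  # 12.0 GB
```

In practice an 8-bit AdamW is commonly taken from bitsandbytes (`bitsandbytes.optim.AdamW8bit`); the card does not say which implementation was used here.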

Learning Rate

We found that manipulating the warm-up/decay curve has some effect, but not a significant one. The optimal learning rate is often overestimated, and our experiments showed that Adam tolerates a wide range of learning rates. We started at 1e-4 and gradually decreased it to 1e-6 during training. In other words, choosing the right model architecture is far more critical than tweaking hyperparameters.
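The 1e-4 to 1e-6 schedule could look like the sketch below. The card does not say which curve was used; cosine decay is an assumption, and `cosine_lr` is a hypothetical helper.

```python
import math


def cosine_lr(step, total_steps, lr_max=1e-4, lr_min=1e-6):
    """Learning rate at `step`, decaying from lr_max to lr_min along a cosine curve."""
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    progress = min(max(step / total_steps, 0.0), 1.0)  # clamp to [0, 1]
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    total = 10_000
    for step in (0, 2_500, 5_000, 7_500, 10_000):
        print(step, f"{cosine_lr(step, total):.2e}")
```

PyTorch's `torch.optim.lr_scheduler.CosineAnnealingLR` with `eta_min=1e-6` follows the same curve when the optimizer's initial learning rate is 1e-4.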

Dataset

We trained the model on approximately 1 million images: 60 epochs on ImageNet at 256 resolution (which proved to be wasted time because of its low-quality annotations) and 8 epochs on CaptionEmporium/midjourney-niji-1m-llavanext, plus realistic photos and anime/art, at 576 resolution. For annotation we used human-written prompts, the prompts provided by CaptionEmporium, SmilingWolf's WD-Tagger, and Moondream2, varying prompt length and composition so that the model learns to understand different prompting styles. The dataset is extremely small, so the model misses many entities and struggles with unseen concepts such as 'a goose on a bicycle.' The dataset also included many waifu-style images, because we were more interested in how well the model learns human anatomy than in its ability to draw 'The Astronaut on horseback.' Although most descriptions were in English, our tests indicate that the model is multilingual.
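One way to vary prompt length and composition per sample, sketched under assumptions: each image carries several captions from different sources (the source names below are hypothetical), one is chosen at random, and it is sometimes truncated to its first few phrases or tags. This illustrates the idea and is not the project's data pipeline.

```python
import random
import re


def sample_caption(captions, rng, shorten_prob=0.3):
    """Pick one caption per sample and sometimes shorten it (hypothetical sketch).

    captions: dict mapping a source name (e.g. "human", "wd_tags", "moondream")
    to caption text. rng: a random.Random instance, for reproducibility.
    """
    available = [text.strip() for text in captions.values() if text and text.strip()]
    if not available:
        return ""  # no usable caption for this image
    text = rng.choice(available)
    if rng.random() < shorten_prob:
        # Split on sentence or tag boundaries and keep a random-length prefix.
        pieces = [p.strip() for p in re.split(r"[.,]", text) if p.strip()]
        if pieces:
            keep = rng.randint(1, len(pieces))
            text = ", ".join(pieces[:keep])
    return text


if __name__ == "__main__":
    rng = random.Random(0)
    captions = {
        "human": "a cat sleeping on a red sofa",
        "wd_tags": "1girl, solo, smile, outdoors, cherry blossoms",
        "moondream": "A woman smiles under cherry trees. Petals drift in the wind.",
    }
    for _ in range(3):
        print(sample_caption(captions, rng))
```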

Limitations

  • Limited concept coverage due to the extremely small dataset.
  • The Image2Image functionality needs further training (we reduced the share of training samples conditioned on SigLIP image embeddings to 5% to focus on text-to-image training).
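The 5% image-conditioning share mentioned in the limitations could be implemented per sample as in this hypothetical sketch. It assumes text and SigLIP image embeddings already share the `[batch, 1, 1152]` shape described in the encoder section; `mix_conditioning` is an illustrative name, not a function from the project.

```python
import torch


def mix_conditioning(text_emb, image_emb, image_prob=0.05, generator=None):
    """Per-sample swap of text embeddings for image embeddings (hypothetical sketch).

    Both tensors are assumed to be [batch, 1, 1152]. Roughly `image_prob` of the
    samples are conditioned on the SigLIP image embedding instead of the text.
    """
    if text_emb.shape != image_emb.shape:
        raise ValueError("text and image embeddings must have the same shape")
    batch = text_emb.shape[0]
    use_image = torch.rand(batch, generator=generator, device=text_emb.device) < image_prob
    # Reshape the per-sample flag so it broadcasts over the remaining dimensions.
    flag = use_image.view(batch, *([1] * (text_emb.dim() - 1)))
    return torch.where(flag, image_emb, text_emb), use_image


if __name__ == "__main__":
    g = torch.Generator().manual_seed(0)
    text = torch.randn(64, 1, 1152)
    image = torch.randn(64, 1, 1152)
    cond, used_image = mix_conditioning(text, image, generator=g)
    print(f"{int(used_image.sum())} of 64 samples used image conditioning")
```

Raising `image_prob` would trade text-to-image focus for stronger Image2Image behaviour, which matches the trade-off described in the limitation above.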

Acknowledgments

  • Stan — Key investor. Primary financial support - thank you for believing in us when others called it madness.
  • Captainsaturnus — Material support.
  • Lovescape & Whargarbl — Moral support.
  • CaptionEmporium — Datasets.

"We believe the future lies in efficient, compact models. We are grateful for the donations and hope for your continued support."
