Usage with Diffusers

#3
by hamzao - opened

Probably not the best nor the fastest solution, but it seems to be working:

from diffusers import FluxPriorReduxPipeline
from transformers import SiglipVisionModel,AutoProcessor
from diffusers.pipelines.flux.modeling_flux import ReduxImageEncoder
from safetensors import safe_open
from huggingface_hub import hf_hub_download
from PIL import Image
import torch

ckpt = "google/siglip2-so400m-patch16-512"
siglip_model = SiglipVisionModel.from_pretrained(ckpt)
siglip_processor = AutoProcessor.from_pretrained(ckpt)
redux_image_encoder = ReduxImageEncoder()
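# Note: the default ReduxImageEncoder dims should already line up with SigLIP
# so400m's hidden size (1152) and the Flux text embedding width (4096), so no
# constructor arguments are passed here.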

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

pretrained_model_path = hf_hub_download(
    repo_id="ostris/Flex.1-alpha-Redux",
    filename="flex1_redux_siglip2_512.safetensors",
)

with safe_open(pretrained_model_path, framework="pt") as f:
    # Load weights for redux_up
    if "redux_up.weight" in f.keys():
        redux_image_encoder.redux_up.weight.data = f.get_tensor("redux_up.weight")
    if "redux_up.bias" in f.keys():
        redux_image_encoder.redux_up.bias.data = f.get_tensor("redux_up.bias")
    # Load weights for redux_down
    if "redux_down.weight" in f.keys():
        redux_image_encoder.redux_down.weight.data = f.get_tensor("redux_down.weight")
    if "redux_down.bias" in f.keys():
        redux_image_encoder.redux_down.bias.data = f.get_tensor("redux_down.bias")

prior_redux = FluxPriorReduxPipeline(
    image_encoder=siglip_model,
    feature_extractor=siglip_processor.image_processor,
    image_embedder=redux_image_encoder,
).to(device).to(dtype)

# Smoke test on a blank image; the call returns the prompt embeddings
# (prompt_embeds / pooled_prompt_embeds) that the base pipeline consumes.
pipe_prior_output = prior_redux(Image.new("L", (1024, 1024), "black"))
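
To actually generate an image from those embeddings, the prior output can be unpacked into a regular FluxPipeline call. A minimal sketch, assuming ostris/Flex.1-alpha as the base checkpoint and typical Flux sampling settings (both are my assumptions, not something confirmed above):

from diffusers import FluxPipeline

base_pipe = FluxPipeline.from_pretrained(
    "ostris/Flex.1-alpha",  # assumed base checkpoint; swap in whichever Flux-style model you pair with this Redux
    torch_dtype=dtype,
).to(device)

# The Redux embeddings replace the text prompt entirely.
image = base_pipe(
    guidance_scale=3.5,
    num_inference_steps=28,
    **pipe_prior_output,  # passes prompt_embeds and pooled_prompt_embeds
).images[0]
image.save("redux_output.png")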