Usage with Diffusers
Discussion #3 — opened by hamzao
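Probably not the best or the fastest solution, but it seems to work. The snippet loads the SigLIP2 vision encoder and the Flex.1 Redux projection weights into a `FluxPriorReduxPipeline`: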
from diffusers import FluxPriorReduxPipeline
from transformers import SiglipVisionModel,AutoProcessor
from diffusers.pipelines.flux.modeling_flux import ReduxImageEncoder
from safetensors import safe_open
from huggingface_hub import hf_hub_download
from PIL import Image
import torch
# SigLIP2 vision tower used as the Redux image encoder, plus its preprocessing.
ckpt = "google/siglip2-so400m-patch16-512"
siglip_model = SiglipVisionModel.from_pretrained(ckpt)
siglip_processor = AutoProcessor.from_pretrained(ckpt)

# Redux projection layers, built with the constructor's default arguments.
# Their weights still have to be copied in from the Redux checkpoint.
redux_image_encoder = ReduxImageEncoder()

# Run in bfloat16 on a GPU when one is present, otherwise float32 on the CPU.
if torch.cuda.is_available():
    device, dtype = torch.device("cuda"), torch.bfloat16
else:
    device, dtype = torch.device("cpu"), torch.float32

# Local path to the Flex.1 Redux weights (downloaded/cached by huggingface_hub).
pretrained_model_path = hf_hub_download(
    repo_id="ostris/Flex.1-alpha-Redux",
    filename="flex1_redux_siglip2_512.safetensors",
)
# Copy the Redux projection weights (redux_up / redux_down) from the
# checkpoint into `redux_image_encoder`.
# Every key is required: silently skipping a missing one would leave that
# layer with its initial, untrained weights. load_state_dict(strict=True)
# also rejects shape mismatches, which a raw `.data =` assignment would accept.
redux_keys = (
    "redux_up.weight",
    "redux_up.bias",
    "redux_down.weight",
    "redux_down.bias",
)
with safe_open(pretrained_model_path, framework="pt") as f:
    available_keys = set(f.keys())
    missing_keys = [key for key in redux_keys if key not in available_keys]
    if missing_keys:
        raise KeyError(
            f"{pretrained_model_path} is missing Redux weights: {missing_keys}"
        )
    redux_state_dict = {key: f.get_tensor(key) for key in redux_keys}
redux_image_encoder.load_state_dict(redux_state_dict, strict=True)
# Assemble the Redux prior pipeline from the SigLIP2 vision tower, its image
# processor and the Redux projection layers, then move the whole pipeline to
# the selected device and precision.
prior_redux = FluxPriorReduxPipeline(
    image_encoder=siglip_model,
    feature_extractor=siglip_processor.image_processor,
    image_embedder=redux_image_encoder,
).to(device, dtype)

# Placeholder input: a 1024x1024 black grayscale image. Replace it with a real
# reference image. The result is kept instead of being discarded. Per the
# diffusers docs it holds the image-conditioned embeddings (prompt_embeds /
# pooled_prompt_embeds) to pass to a Flux pipeline; confirm for your version.
redux_output = prior_redux(Image.new("L", (1024, 1024), "black"))