############################## # ===== Standard Imports ===== ############################## import os import sys import time import random import json from math import floor from typing import Any, Dict, List, Optional, Union # Local import for default LoRA list (if available) try: from flux_app.lora import loras except ImportError: loras = [ {"image": "placeholder.jpg", "title": "Placeholder LoRA", "repo": "placeholder/repo", "weights": None, "trigger_word": ""} ] import torch import numpy as np import requests from PIL import Image import spaces # Diffusers imports from diffusers import ( DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image, ) from diffusers.utils import load_image # Hugging Face Hub from huggingface_hub import ModelCard, HfFileSystem # Gradio (UI) import gradio as gr ############################## # ===== config.py ===== ############################## DTYPE = torch.bfloat16 DEVICE = "cuda" if torch.cuda.is_available() else "cpu" BASE_MODEL = "black-forest-labs/FLUX.1-dev" TAEF1_MODEL = "madebyollin/taef1" MAX_SEED = 2**32 - 1 ############################## # ===== utilities.py ===== ############################## def calculate_shift( image_seq_len, base_seq_len: int = 256, max_seq_len: int = 4096, base_shift: float = 0.5, max_shift: float = 1.16, ): m = (max_shift - base_shift) / (max_seq_len - base_seq_len) b = base_shift - m * base_seq_len mu = image_seq_len * m + b return mu def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, **kwargs, ): if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps def load_image_from_path(image_path: str): """Loads an image from a given file path.""" return load_image(image_path) def randomize_seed_if_needed(randomize_seed: bool, seed: int, max_seed: int) -> int: """Randomizes the seed if requested.""" if randomize_seed: return random.randint(0, max_seed) return seed class calculateDuration: def __init__(self, activity_name=""): self.activity_name = activity_name def __enter__(self): self.start_time = time.time() return self def __exit__(self, exc_type, exc_value, traceback): self.end_time = time.time() self.elapsed_time = self.end_time - self.start_time if self.activity_name: print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds") else: print(f"Elapsed time: {self.elapsed_time:.6f} seconds") ############################## # ===== enhance.py ===== ############################## def generate(message, max_new_tokens=256, temperature=0.9, top_p=0.95, repetition_penalty=1.0): """ Generates an enhanced prompt using a streaming Hugging Face API. Enhances the given prompt under 100 words without changing its essence. """ SYSTEM_PROMPT = ( "You are a prompt enhancer and your work is to enhance the given prompt under 100 words " "without changing the essence, only write the enhanced prompt and nothing else." ) timestamp = time.time() formatted_prompt = ( f"[INST] SYSTEM: {SYSTEM_PROMPT} [/INST]" f"[INST] {message} {timestamp} [/INST]" ) api_url = "https://ruslanmv-hf-llm-api.hf.space/api/v1/chat/completions" headers = {"Content-Type": "application/json"} payload = { "model": "mixtral-8x7b", "messages": [{"role": "user", "content": formatted_prompt}], "temperature": temperature, "top_p": top_p, "max_tokens": max_new_tokens, "use_cache": False, "stream": True } try: response = requests.post(api_url, headers=headers, json=payload, stream=True) response.raise_for_status() full_output = "" for line in response.iter_lines(): if not line: continue decoded_line = line.decode("utf-8").strip() if decoded_line.startswith("data:"): decoded_line = decoded_line[len("data:"):].strip() if decoded_line == "[DONE]": break try: json_data = json.loads(decoded_line) for choice in json_data.get("choices", []): delta = choice.get("delta", {}) content = delta.get("content", "") full_output += content yield full_output if choice.get("finish_reason") == "stop": return except json.JSONDecodeError: continue except requests.exceptions.RequestException as e: yield f"Error during generation: {str(e)}" ############################## # ===== lora_handling.py ===== ############################## # A default list of LoRAs for the UI loras = [ {"image": "placeholder.jpg", "title": "Placeholder LoRA", "repo": "placeholder/repo", "weights": None, "trigger_word": ""} ] @torch.inference_mode() def flux_pipe_call_that_returns_an_iterable_of_images( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 28, timesteps: List[int] = None, guidance_scale: float = 3.5, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, max_sequence_length: int = 512, good_vae: Optional[Any] = None, ): height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor self.check_inputs( prompt, prompt_2, height, width, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, max_sequence_length=max_sequence_length, ) self._guidance_scale = guidance_scale self._joint_attention_kwargs = joint_attention_kwargs self._interrupt = False batch_size = 1 if isinstance(prompt, str) else len(prompt) device = self._execution_device lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, lora_scale=lora_scale, ) num_channels_latents = self.transformer.config.in_channels // 4 latents, latent_image_ids = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, self.scheduler.config.max_image_seq_len, self.scheduler.config.base_shift, self.scheduler.config.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu, ) self._num_timesteps = len(timesteps) guidance = (torch.full([1], guidance_scale, device=device, dtype=torch.float32) .expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None) for i, t in enumerate(timesteps): if self.interrupt: continue timestep = t.expand(latents.shape[0]).to(latents.dtype) noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor) latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor image = self.vae.decode(latents_for_image, return_dict=False)[0] yield self.image_processor.postprocess(image, output_type=output_type)[0] latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] torch.cuda.empty_cache() latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor image = good_vae.decode(latents, return_dict=False)[0] self.maybe_free_model_hooks() torch.cuda.empty_cache() yield self.image_processor.postprocess(image, output_type=output_type)[0] def get_huggingface_safetensors(link: str) -> tuple: split_link = link.split("/") if len(split_link) == 2: model_card = ModelCard.load(link) base_model = model_card.data.get("base_model") print(base_model) if base_model not in ("black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"): raise Exception("Flux LoRA Not Found!") image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None) trigger_word = model_card.data.get("instance_prompt", "") image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None fs = HfFileSystem() try: list_of_files = fs.ls(link, detail=False) for file in list_of_files: if file.endswith(".safetensors"): safetensors_name = file.split("/")[-1] if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")): image_elements = file.split("/") image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}" return split_link[1], link, safetensors_name, trigger_word, image_url except Exception as e: print(e) raise Exception("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA") else: raise Exception("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA") def check_custom_model(link: str) -> tuple: if link.startswith("https://"): if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"): link_split = link.split("huggingface.co/") return get_huggingface_safetensors(link_split[1]) return get_huggingface_safetensors(link) def create_lora_card(title: str, repo: str, trigger_word: str, image: str) -> str: trigger_word_info = ( f"Using: {trigger_word} as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt" ) return f'''
Loaded custom LoRA:

{title}

{trigger_word_info}
''' def add_custom_lora(custom_lora: str, loras_list: list) -> tuple: if custom_lora: try: title, repo, path, trigger_word, image = check_custom_model(custom_lora) print(f"Loaded custom LoRA: {repo}") card = create_lora_card(title, repo, trigger_word, image) existing_item_index = next((index for (index, item) in enumerate(loras_list) if item['repo'] == repo), None) if existing_item_index is None: new_item = { "image": image, "title": title, "repo": repo, "weights": path, "trigger_word": trigger_word } print(new_item) loras_list.append(new_item) existing_item_index = len(loras_list) - 1 return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word except Exception as e: print(f"Error loading LoRA: {e}") return gr.update(visible=True, value="Invalid LoRA"), gr.update(visible=False), gr.update(), "", None, "" else: return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, "" def remove_custom_lora() -> tuple: return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, "" def prepare_prompt(prompt: str, selected_index: Optional[int], loras_list: list) -> str: if selected_index is None: raise gr.Error("You must select a LoRA before proceeding.🧨") selected_lora = loras_list[selected_index] trigger_word = selected_lora.get("trigger_word") if trigger_word: trigger_position = selected_lora.get("trigger_position", "append") if trigger_position == "prepend": prompt_mash = f"{trigger_word} {prompt}" else: prompt_mash = f"{prompt} {trigger_word}" else: prompt_mash = prompt return prompt_mash def unload_lora_weights(pipe, pipe_i2i): if pipe is not None: pipe.unload_lora_weights() if pipe_i2i is not None: pipe_i2i.unload_lora_weights() def load_lora_weights_into_pipeline(pipe_to_use, lora_path: str, weight_name: Optional[str]): pipe_to_use.load_lora_weights( lora_path, weight_name=weight_name, low_cpu_mem_usage=True ) def update_selection(evt: gr.SelectData, width, height, loras_list): selected_lora = loras_list[evt.index] new_placeholder = f"Type a prompt for {selected_lora['title']}" lora_repo = selected_lora["repo"] updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✅" if "aspect" in selected_lora: if selected_lora["aspect"] == "portrait": width = 768 height = 1024 elif selected_lora["aspect"] == "landscape": width = 1024 height = 768 else: width = 1024 height = 1024 return ( gr.update(placeholder=new_placeholder), updated_text, evt.index, width, height, ) ############################## # ===== backend.py ===== ############################## class ModelManager: def __init__(self, hf_token=None): self.hf_token = hf_token self.pipe = None self.pipe_i2i = None self.good_vae = None self.taef1 = None self.initialize_models() def initialize_models(self): """Initializes the diffusion pipelines and autoencoders.""" self.taef1 = AutoencoderTiny.from_pretrained(TAEF1_MODEL, torch_dtype=DTYPE).to(DEVICE) self.good_vae = AutoencoderKL.from_pretrained(BASE_MODEL, subfolder="vae", torch_dtype=DTYPE).to(DEVICE) # Optionally, pass use_auth_token=self.hf_token if needed. self.pipe = DiffusionPipeline.from_pretrained(BASE_MODEL, torch_dtype=DTYPE, vae=self.taef1) self.pipe = self.pipe.to(DEVICE) self.pipe_i2i = AutoPipelineForImage2Image.from_pretrained( BASE_MODEL, vae=self.good_vae, transformer=self.pipe.transformer, text_encoder=self.pipe.text_encoder, tokenizer=self.pipe.tokenizer, text_encoder_2=self.pipe.text_encoder_2, tokenizer_2=self.pipe.tokenizer_2, torch_dtype=DTYPE, ).to(DEVICE) # Instead of binding to the instance (which fails due to __slots__), # bind the custom method to the pipeline’s class. self.pipe.__class__.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images @spaces.GPU(duration=100) def generate_image(self, prompt_mash, steps, seed, cfg_scale, width, height, lora_scale): """Generates an image using the text-to-image pipeline.""" self.pipe.to(DEVICE) generator = torch.Generator(device=DEVICE).manual_seed(seed) with calculateDuration("Generating image"): for img in self.pipe.flux_pipe_call_that_returns_an_iterable_of_images( prompt=prompt_mash, num_inference_steps=steps, guidance_scale=cfg_scale, width=width, height=height, generator=generator, joint_attention_kwargs={"scale": lora_scale}, output_type="pil", good_vae=self.good_vae, ): yield img def generate_image_to_image(self, prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed): """Generates an image using the image-to-image pipeline.""" generator = torch.Generator(device=DEVICE).manual_seed(seed) self.pipe_i2i.to(DEVICE) image_input = load_image_from_path(image_input_path) with calculateDuration("Generating image to image"): final_image = self.pipe_i2i( prompt=prompt_mash, image=image_input, strength=image_strength, num_inference_steps=steps, guidance_scale=cfg_scale, width=width, height=height, generator=generator, joint_attention_kwargs={"scale": lora_scale}, output_type="pil", ).images[0] return final_image ############################## # ===== frontend.py ===== ############################## class Frontend: def __init__(self, model_manager: ModelManager): self.model_manager = model_manager self.loras = loras # Use the default LoRA list defined above. self.load_initial_loras() self.css = self.define_css() def define_css(self): # Clean and professional CSS styling. return ''' /* Title Styling */ #title { text-align: center; margin-bottom: 20px; } #title h1 { font-size: 2.5rem; margin: 0; color: #333; } /* Button and Column Styling */ #gen_btn { width: 100%; padding: 12px; font-weight: bold; border-radius: 5px; } #gen_column { display: flex; align-items: center; justify-content: center; } /* Gallery and List Styling */ #gallery .grid-wrap { margin-top: 15px; } #lora_list { background-color: #f5f5f5; padding: 10px; border-radius: 4px; font-size: 0.9rem; } .card_internal { display: flex; align-items: center; height: 100px; margin-top: 10px; } .card_internal img { margin-right: 10px; } .styler { --form-gap-width: 0px !important; } /* Progress Bar Styling */ .progress-container { width: 100%; height: 20px; background-color: #e0e0e0; border-radius: 10px; overflow: hidden; margin-bottom: 20px; } .progress-bar { height: 100%; background-color: #4f46e5; transition: width 0.3s ease-in-out; width: calc(var(--current) / var(--total) * 100%); } ''' def load_initial_loras(self): try: from lora import loras as loras_list self.loras = loras_list except ImportError: print("Warning: lora.py not found, using placeholder LoRAs.") pass @spaces.GPU(duration=100) def run_lora(self, prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, use_enhancer, progress=gr.Progress(track_tqdm=True)): seed = randomize_seed_if_needed(randomize_seed, seed, MAX_SEED) # Prepare the prompt using the selected LoRA trigger word. prompt_mash = prepare_prompt(prompt, selected_index, self.loras) enhanced_text = "" # Optionally enhance the prompt. if use_enhancer: for enhanced_chunk in generate(prompt_mash): enhanced_text = enhanced_chunk yield None, seed, gr.update(visible=False), enhanced_text prompt_mash = enhanced_text else: enhanced_text = "" selected_lora = self.loras[selected_index] unload_lora_weights(self.model_manager.pipe, self.model_manager.pipe_i2i) pipe_to_use = self.model_manager.pipe_i2i if image_input is not None else self.model_manager.pipe load_lora_weights_into_pipeline(pipe_to_use, selected_lora["repo"], selected_lora.get("weights")) if image_input is not None: final_image = self.model_manager.generate_image_to_image( prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed ) yield final_image, seed, gr.update(visible=False), enhanced_text else: image_generator = self.model_manager.generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale) final_image = None step_counter = 0 for image in image_generator: step_counter += 1 final_image = image progress_bar = f'
' yield image, seed, gr.update(value=progress_bar, visible=True), enhanced_text yield final_image, seed, gr.update(value=progress_bar, visible=False), enhanced_text def create_ui(self): with gr.Blocks(theme=gr.themes.Base(), css=self.css, title="Flux LoRA Generation") as app: title = gr.HTML( """

Flux LoRA Generation

""", elem_id="title", ) selected_index = gr.State(None) with gr.Row(): with gr.Column(scale=3): prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Choose the LoRA and type the prompt") with gr.Column(scale=1, elem_id="gen_column"): generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn") with gr.Row(): with gr.Column(): selected_info = gr.Markdown("") gallery = gr.Gallery( [(item["image"], item["title"]) for item in self.loras], label="LoRA Collection", allow_preview=False, columns=3, elem_id="gallery", show_share_button=False ) with gr.Group(): custom_lora = gr.Textbox(label="Enter Custom LoRA", placeholder="prithivMLmods/Canopus-LoRA-Flux-Anime") gr.Markdown("[Check the list of FLUX LoRA's](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list") custom_lora_info = gr.HTML(visible=False) custom_lora_button = gr.Button("Remove custom LoRA", visible=False) with gr.Column(): progress_bar = gr.Markdown(elem_id="progress", visible=False) result = gr.Image(label="Generated Image") with gr.Row(): with gr.Accordion("Advanced Settings", open=False): with gr.Row(): input_image = gr.Image(label="Input image", type="filepath") image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75) with gr.Column(): with gr.Row(): cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5) steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28) with gr.Row(): width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024) height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024) with gr.Row(): randomize_seed = gr.Checkbox(True, label="Randomize seed") seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True) lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=3, step=0.01, value=0.95) with gr.Row(): use_enhancer = gr.Checkbox(value=False, label="Use Prompt Enhancer") show_enhanced_prompt = gr.Checkbox(value=False, label="Display Enhanced Prompt") enhanced_prompt_box = gr.Textbox(label="Enhanced Prompt", visible=False) gallery.select( update_selection, inputs=[width, height, gr.State(self.loras)], outputs=[prompt, selected_info, selected_index, width, height] ) custom_lora.input( add_custom_lora, inputs=[custom_lora, gr.State(self.loras)], outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt] ) custom_lora_button.click( remove_custom_lora, outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora] ) show_enhanced_prompt.change(fn=lambda show: gr.update(visible=show), inputs=show_enhanced_prompt, outputs=enhanced_prompt_box) gr.on( triggers=[generate_button.click, prompt.submit], fn=self.run_lora, inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, use_enhancer], outputs=[result, seed, progress_bar, enhanced_prompt_box] ) with gr.Row(): gr.HTML("
Credits: ruslanmv.com
") return app ############################## # ===== Main app.py ===== ############################## if __name__ == "__main__": # Get the Hugging Face token from the environment. hf_token = os.environ.get("HF_TOKEN") if not hf_token: raise ValueError("Hugging Face token (HF_TOKEN) not found in environment variables. Please set it.") model_manager = ModelManager(hf_token=hf_token) frontend = Frontend(model_manager) app = frontend.create_ui() app.queue() # Set share=True to create a public link if desired. app.launch(share=False, debug=True)