|
import os |
|
from copy import copy |
|
from enum import Enum |
|
from typing import Tuple, List |
|
|
|
from modules import img2img, processing, shared, script_callbacks |
|
from scripts import external_code |
|
|
|
|
|
class BatchHijack:
    """Patch webui batch entry points so ControlNet units in batch mode iterate over their images.

    :meth:`do_hijack` replaces:

    * ``img2img.process_batch`` with :meth:`img2img_process_batch_hijack`
    * ``processing.process_images_inner`` with :meth:`processing_process_images_hijack`

    The originals are stashed on their modules as ``__controlnet_original_*`` and are
    always fetched with ``getattr(module, '<string>')``. Writing the attribute directly
    (``img2img.__controlnet_original_process_batch``) inside this class body would be
    private-name-mangled to ``_BatchHijack__controlnet_original_process_batch``.
    """

    def __init__(self):
        # True while a ControlNet batch run is in progress.
        self.is_batch = False
        # Iterations completed so far in the current run.
        self.batch_index = 0
        # Iterations planned for the current run (1 when idle).
        self.batch_size = 1
        # Whether seeds are incremented per iteration; sampled once per run in on_process_batch.
        self.increment_seed = False
        # Seeds captured at run start when increment_seed is enabled; None otherwise.
        self.init_seed = None
        self.init_subseed = None

        # Callbacks dispatched at run start, before/after each iteration, and at run end.
        # Kept as public lists, presumably so other code can register more — confirm.
        self.process_batch_callbacks = [self.on_process_batch]
        self.process_batch_each_callbacks = []
        self.postprocess_batch_each_callbacks = [self.on_postprocess_batch_each]
        self.postprocess_batch_callbacks = [self.on_postprocess_batch]

    def img2img_process_batch_hijack(self, p, *args, **kwargs):
        """Replacement for ``img2img.process_batch``.

        Without any enabled ControlNet unit in batch mode this just calls the original.
        Otherwise the original call is bracketed by the process/postprocess batch
        callbacks, which set and reset ``self.is_batch`` around it.
        """
        cn_is_batch, batches, output_dir, _ = get_cn_batches(p)
        original = getattr(img2img, '__controlnet_original_process_batch')
        if not cn_is_batch:
            return original(p, *args, **kwargs)

        try:
            # Setup is inside the try so the batch state is always reset, even if a
            # start-of-run callback raises (mirrors processing_process_images_hijack).
            self.dispatch_callbacks(self.process_batch_callbacks, p, batches, output_dir)
            return original(p, *args, **kwargs)
        finally:
            self.dispatch_callbacks(self.postprocess_batch_callbacks, p)

    def processing_process_images_hijack(self, p, *args, **kwargs):
        """Replacement for ``processing.process_images_inner``.

        * During a batch run started elsewhere (``self.is_batch``), perform exactly one
          ControlNet batch iteration.
        * When no enabled unit is in batch mode, call the original unchanged.
        * Otherwise, run one generation per ControlNet batch entry. Each result is
          saved to ``output_dir`` when one is set. Results are also collected for the
          UI when ``controlnet_show_batch_images_in_ui`` is enabled.

        Returns the last iteration's ``Processed`` with ``images`` replaced by the
        collected images. If nothing was collected, returns an empty ``Processed``.
        """
        if self.is_batch:
            return self.process_images_cn_batch(p, *args, **kwargs)

        cn_is_batch, batches, output_dir, input_file_names = get_cn_batches(p)
        if not cn_is_batch:
            return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)

        output_images = []
        try:
            self.dispatch_callbacks(self.process_batch_callbacks, p, batches, output_dir)

            for batch_i in range(self.batch_size):
                processed = self.process_images_cn_batch(p, *args, **kwargs)
                new_images = processed.images[processed.index_of_first_image:]
                if shared.opts.data.get('controlnet_show_batch_images_in_ui', False):
                    output_images.extend(new_images)

                if output_dir:
                    self.save_images(output_dir, input_file_names[batch_i], new_images)

                if shared.state.interrupted:
                    break
        finally:
            self.dispatch_callbacks(self.postprocess_batch_callbacks, p)

        if output_images:
            # Non-empty only if at least one iteration ran, so `processed` is bound here.
            processed.images = output_images
        else:
            processed = processing.Processed(p, [], p.seed)

        return processed

    def process_images_cn_batch(self, p, *args, **kwargs):
        """Run a single ControlNet batch iteration through the original ``process_images_inner``.

        The ``control_net_no_detectmap`` option is forced to True for the call and
        restored afterwards, even if the call raises. Once the planned number of
        iterations has been reached, ``shared.state.interrupted`` is set. The loop
        in ``processing_process_images_hijack`` checks that flag; webui's img2img batch
        loop presumably does too — confirm.
        """
        self.dispatch_callbacks(self.process_batch_each_callbacks, p)

        old_detectmap_output = shared.opts.data.get('control_net_no_detectmap', False)
        try:
            shared.opts.data.update({'control_net_no_detectmap': True})
            processed = getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)
        finally:
            shared.opts.data.update({'control_net_no_detectmap': old_detectmap_output})

        self.dispatch_callbacks(self.postprocess_batch_each_callbacks, p, processed)

        # Do not go past the ControlNet batch size.
        if self.batch_index >= self.batch_size:
            shared.state.interrupted = True

        return processed

    def save_images(self, output_dir, init_image_path, output_images):
        """Save ``output_images`` into ``output_dir``, named after ``init_image_path``.

        The first image reuses the input's base name. Image ``n`` (n > 0) gets a
        ``-n`` suffix before the extension. RGBA images are converted to RGB first,
        presumably so alpha-less formats such as JPEG can be written. ``output_dir``
        is created if it does not exist.
        """
        os.makedirs(output_dir, exist_ok=True)

        # Loop-invariant: every output derives its name from the same input file.
        base_name = os.path.basename(init_image_path)
        stem, ext = os.path.splitext(base_name)

        for n, processed_image in enumerate(output_images):
            filename = base_name if n == 0 else f"{stem}-{n}{ext}"
            if processed_image.mode == 'RGBA':
                processed_image = processed_image.convert("RGB")
            processed_image.save(os.path.join(output_dir, filename))

    def do_hijack(self):
        """Install both function hijacks.

        Also registers :meth:`undo_hijack` via ``script_callbacks.on_script_unloaded``.
        """
        script_callbacks.on_script_unloaded(self.undo_hijack)
        hijack_function(
            module=img2img,
            name='process_batch',
            new_name='__controlnet_original_process_batch',
            new_value=self.img2img_process_batch_hijack,
        )
        hijack_function(
            module=processing,
            name='process_images_inner',
            new_name='__controlnet_original_process_images_inner',
            new_value=self.processing_process_images_hijack,
        )

    def undo_hijack(self):
        """Restore the original ``img2img.process_batch`` and ``processing.process_images_inner``."""
        unhijack_function(
            module=img2img,
            name='process_batch',
            new_name='__controlnet_original_process_batch',
        )
        unhijack_function(
            module=processing,
            name='process_images_inner',
            new_name='__controlnet_original_process_images_inner',
        )

    def adjust_job_count(self, p):
        """Multiply ``shared.state.job_count`` by the number of ControlNet batch iterations.

        A job count of -1 is first replaced by ``p.n_iter``. Presumably -1 means
        "not set yet" in webui — confirm.
        """
        if shared.state.job_count == -1:
            shared.state.job_count = p.n_iter
        shared.state.job_count *= self.batch_size

    def on_process_batch(self, p, batches, output_dir, *args):
        """Start-of-run callback: initialise batch state, seeds, job count and save flags.

        The ``controlnet_increment_seed_during_batch`` option is read once here.
        The per-iteration and end-of-run callbacks then use that decision, so
        toggling the option mid-run cannot restore a missing or stale seed.
        """
        print('controlnet batch mode')
        self.is_batch = True
        self.batch_index = 0
        self.batch_size = len(batches)
        processing.fix_seed(p)

        self.increment_seed = bool(shared.opts.data.get('controlnet_increment_seed_during_batch', False))
        if self.increment_seed:
            self.init_seed = p.seed
            self.init_subseed = p.subseed
        else:
            self.init_seed = None
            self.init_subseed = None

        self.adjust_job_count(p)
        # Grids are never saved during a ControlNet batch run.
        p.do_not_save_grid = True
        # With an explicit output dir, webui's own sample saving is turned off;
        # presumably to avoid duplicating save_images output — confirm.
        p.do_not_save_samples = bool(output_dir)

    def on_postprocess_batch_each(self, p, *args):
        """Per-iteration callback: count the iteration and, if enabled, advance the seeds.

        Seeds move by ``len(p.all_prompts)`` so consecutive iterations do not reuse seeds.
        """
        self.batch_index += 1
        if self.increment_seed:
            p.seed = p.seed + len(p.all_prompts)
            p.subseed = p.subseed + len(p.all_prompts)

    def on_postprocess_batch(self, p, *args):
        """End-of-run callback: reset batch state and restore the seeds captured at run start."""
        self.is_batch = False
        self.batch_index = 0
        self.batch_size = 1
        if self.increment_seed:
            p.seed = self.init_seed
            p.all_seeds = [self.init_seed]
            p.subseed = self.init_subseed
            p.all_subseeds = [self.init_subseed]
        # Clear the captured seeds so they can never leak into a later run.
        self.increment_seed = False
        self.init_seed = None
        self.init_subseed = None

    def dispatch_callbacks(self, callbacks, *args):
        """Call every callback in ``callbacks`` in order with the same positional ``args``."""
        for callback in callbacks:
            callback(*args)
|
|
|
|
|
def hijack_function(module, name, new_name, new_value):
    """Replace ``module.<name>`` with ``new_value``, keeping the original as ``module.<new_name>``.

    Any hijack already stored under ``new_name`` is undone first. This makes calling
    the function more than once safe: a previously installed replacement is never
    stashed as if it were the original.
    """
    unhijack_function(module, name, new_name)
    original_value = getattr(module, name)
    setattr(module, new_name, original_value)
    setattr(module, name, new_value)
|
|
|
|
|
def unhijack_function(module, name, new_name):
    """Move the stashed ``module.<new_name>`` back to ``module.<name>`` and delete the stash.

    Does nothing when ``module`` has no attribute ``new_name``, i.e. nothing is hijacked.
    """
    try:
        original_value = getattr(module, new_name)
    except AttributeError:
        return
    else:
        setattr(module, name, original_value)
        delattr(module, new_name)
|
|
|
|
|
class InputMode(Enum):
    """Where a ControlNet unit takes its control image(s) from.

    Matched against a unit's ``input_mode`` attribute; units without that attribute
    are treated as ``SIMPLE``.
    """

    # One image per unit, read from ``unit.image``.
    SIMPLE = "simple"
    # One image per batch iteration, read from ``unit.batch_images``.
    BATCH = "batch"
|
|
|
|
|
def get_cn_batches(p: processing.StableDiffusionProcessing) -> Tuple[bool, List[List[str]], str, List[str]]:
    """Collect the per-iteration ControlNet inputs for the enabled units of ``p``.

    Returns ``(any_unit_is_batch, batches, output_dir, input_file_names)``:

    * ``any_unit_is_batch``: True when at least one enabled unit has
      ``input_mode == InputMode.BATCH``.
    * ``batches``: one list per iteration, with one entry per enabled unit in unit
      order. A batch unit contributes ``batch_images[i]``; any other unit contributes
      its ``image``. There are as many iterations as the shortest ``batch_images``
      among batch units, or exactly 1 when no unit is in batch mode.
    * ``output_dir`` and ``input_file_names``: the ``output_dir`` attribute and the
      expanded ``batch_images`` of the *last* batch unit (``''`` / ``[]`` if none).

    Units are shallow-copied before a string ``batch_images`` is expanded with
    ``shared.listfiles``. The units held by ``p`` are therefore not modified.
    """
    units = [
        copy(unit)
        for unit in external_code.get_all_units_in_processing(p)
        if getattr(unit, 'enabled', False)
    ]

    def is_batch_unit(unit) -> bool:
        # Single predicate used everywhere, so selection and batch building agree.
        return getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH

    batch_flags = [is_batch_unit(unit) for unit in units]
    batch_units = [unit for unit, is_batch in zip(units, batch_flags) if is_batch]

    output_dir = ''
    input_file_names = []
    for unit in batch_units:
        output_dir = getattr(unit, 'output_dir', '')
        if isinstance(unit.batch_images, str):
            unit.batch_images = shared.listfiles(unit.batch_images)
        input_file_names = unit.batch_images

    # Shortest batch wins so every batch unit has an image for each iteration.
    cn_batch_size = min((len(unit.batch_images) for unit in batch_units), default=1)

    batches = [
        [unit.batch_images[i] if is_batch else unit.image for unit, is_batch in zip(units, batch_flags)]
        for i in range(cn_batch_size)
    ]

    return bool(batch_units), batches, output_dir, input_file_names
|
|
|
|
|
# Module-level BatchHijack instance. Presumably other scripts install the hijacks
# through it (``instance.do_hijack()``) — confirm against callers.
instance = BatchHijack()
|
|