toto10's picture
Upload folder using huggingface_hub (#1)
34097e9
raw
history blame
8.17 kB
import os
from copy import copy
from enum import Enum
from typing import Tuple, List
from modules import img2img, processing, shared, script_callbacks
from scripts import external_code
class BatchHijack:
    """Implement ControlNet "batch" input mode by hijacking webui processing functions.

    ``do_hijack`` replaces ``img2img.process_batch`` and
    ``processing.process_images_inner`` with methods of this object; the
    originals are kept on the same modules under ``__controlnet_original_*``
    attribute names and restored by ``undo_hijack``.

    While a ControlNet batch is running, generation is repeated once per entry
    of the ControlNet batch (see ``get_cn_batches``), with optional per-iteration
    seed incrementing and optional saving of outputs to a directory.

    The lifecycle hooks are plain lists of callables, each invoked through
    ``dispatch_callbacks``:

    * ``process_batch_callbacks`` -- once, ``(p, batches, output_dir)``, before a run
    * ``process_batch_each_callbacks`` -- ``(p,)`` before each iteration
    * ``postprocess_batch_each_callbacks`` -- ``(p, processed)`` after each iteration
    * ``postprocess_batch_callbacks`` -- once, ``(p,)``, after a run (in a ``finally``)
    """

    def __init__(self):
        # True while a ControlNet batch run is in progress.
        self.is_batch = False
        # Number of iterations completed so far in the current run.
        self.batch_index = 0
        # Number of iterations the current run will perform.
        self.batch_size = 1
        # Seeds captured at the start of a run (only when seed incrementing is
        # enabled) so they can be restored when the run ends.
        self.init_seed = None
        self.init_subseed = None
        self.process_batch_callbacks = [self.on_process_batch]
        self.process_batch_each_callbacks = []
        self.postprocess_batch_each_callbacks = [self.on_postprocess_batch_each]
        self.postprocess_batch_callbacks = [self.on_postprocess_batch]

    def img2img_process_batch_hijack(self, p, *args, **kwargs):
        """Replacement for ``img2img.process_batch``.

        Forwards straight to the original when no enabled unit is in batch
        mode. Otherwise the batch state is set up around the original call and
        is always torn down afterwards, even if setup or the call raises.
        """
        cn_is_batch, batches, output_dir, _ = get_cn_batches(p)
        if not cn_is_batch:
            return getattr(img2img, '__controlnet_original_process_batch')(p, *args, **kwargs)

        try:
            # Setup runs inside the try (as in processing_process_images_hijack)
            # so a failing setup callback cannot leave ``is_batch`` stuck at
            # True for every later generation.
            self.dispatch_callbacks(self.process_batch_callbacks, p, batches, output_dir)
            return getattr(img2img, '__controlnet_original_process_batch')(p, *args, **kwargs)
        finally:
            self.dispatch_callbacks(self.postprocess_batch_callbacks, p)

    def processing_process_images_hijack(self, p, *args, **kwargs):
        """Replacement for ``processing.process_images_inner``.

        * Inside an img2img batch-tab run (``is_batch``), perform one iteration.
        * Without any ControlNet batch unit, forward to the original function.
        * Otherwise run the whole ControlNet batch here: one iteration per batch
          entry, optionally saving each iteration's images to ``output_dir``
          and optionally collecting them for the UI.
        """
        if self.is_batch:
            # we are in img2img batch tab, do a single batch iteration
            return self.process_images_cn_batch(p, *args, **kwargs)

        cn_is_batch, batches, output_dir, input_file_names = get_cn_batches(p)
        if not cn_is_batch:
            # we are not in batch mode, fallback to original function
            return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)

        output_images = []
        try:
            self.dispatch_callbacks(self.process_batch_callbacks, p, batches, output_dir)

            for batch_i in range(self.batch_size):
                processed = self.process_images_cn_batch(p, *args, **kwargs)
                # Images before index_of_first_image are skipped (not collected
                # for the UI, not saved).
                new_images = processed.images[processed.index_of_first_image:]
                if shared.opts.data.get('controlnet_show_batch_images_in_ui', False):
                    output_images.extend(new_images)

                if output_dir:
                    self.save_images(output_dir, input_file_names[batch_i], new_images)

                if shared.state.interrupted:
                    break
        finally:
            self.dispatch_callbacks(self.postprocess_batch_callbacks, p)

        if output_images:
            # NOTE(review): only ``images`` is replaced; other fields of the last
            # iteration's result (e.g. index_of_first_image) are kept as-is --
            # confirm the UI tolerates that.
            processed.images = output_images
        else:
            # Also covers a zero-length batch, where ``processed`` was never bound.
            processed = processing.Processed(p, [], p.seed)

        return processed

    def process_images_cn_batch(self, p, *args, **kwargs):
        """Run a single batch iteration through the original ``process_images_inner``.

        ``control_net_no_detectmap`` is forced on for the duration of the call
        and restored afterwards. Once ``batch_index`` reaches ``batch_size``,
        ``shared.state.interrupted`` is set to stop the enclosing loop.
        """
        self.dispatch_callbacks(self.process_batch_each_callbacks, p)
        old_detectmap_output = shared.opts.data.get('control_net_no_detectmap', False)
        try:
            shared.opts.data.update({'control_net_no_detectmap': True})
            processed = getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)
        finally:
            shared.opts.data.update({'control_net_no_detectmap': old_detectmap_output})

        self.dispatch_callbacks(self.postprocess_batch_each_callbacks, p, processed)

        # do not go past control net batch size
        if self.batch_index >= self.batch_size:
            shared.state.interrupted = True

        return processed

    def save_images(self, output_dir, init_image_path, output_images):
        """Save *output_images* into *output_dir*, named after *init_image_path*.

        The first image reuses the input's basename; image ``n > 0`` gets a
        ``-n`` suffix before the extension. RGBA images are converted to RGB
        before saving (presumably so alpha-less formats such as JPEG work).
        """
        os.makedirs(output_dir, exist_ok=True)
        for n, processed_image in enumerate(output_images):
            filename = os.path.basename(init_image_path)

            if n > 0:
                left, right = os.path.splitext(filename)
                filename = f"{left}-{n}{right}"

            if processed_image.mode == 'RGBA':
                processed_image = processed_image.convert("RGB")
            processed_image.save(os.path.join(output_dir, filename))

    def do_hijack(self):
        """Install the replacement functions and register their removal on script unload."""
        script_callbacks.on_script_unloaded(self.undo_hijack)
        hijack_function(
            module=img2img,
            name='process_batch',
            new_name='__controlnet_original_process_batch',
            new_value=self.img2img_process_batch_hijack,
        )
        hijack_function(
            module=processing,
            name='process_images_inner',
            new_name='__controlnet_original_process_images_inner',
            new_value=self.processing_process_images_hijack,
        )

    def undo_hijack(self):
        """Restore the original functions replaced by ``do_hijack``."""
        unhijack_function(
            module=img2img,
            name='process_batch',
            new_name='__controlnet_original_process_batch',
        )
        unhijack_function(
            module=processing,
            name='process_images_inner',
            new_name='__controlnet_original_process_images_inner',
        )

    def adjust_job_count(self, p):
        """Scale the progress job count by the ControlNet batch size.

        A job count of -1 is replaced by ``p.n_iter`` before scaling.
        """
        if shared.state.job_count == -1:
            shared.state.job_count = p.n_iter
        shared.state.job_count *= self.batch_size

    def on_process_batch(self, p, batches, output_dir, *args):
        """Start a batch run: initialise state, fix seeds and adjust saving/progress."""
        print('controlnet batch mode')
        self.is_batch = True
        self.batch_index = 0
        self.batch_size = len(batches)
        processing.fix_seed(p)
        if shared.opts.data.get('controlnet_increment_seed_during_batch', False):
            self.init_seed = p.seed
            self.init_subseed = p.subseed
        self.adjust_job_count(p)
        p.do_not_save_grid = True
        # When an output directory is given, save_images writes the results,
        # so the normal sample saving is disabled.
        p.do_not_save_samples = bool(output_dir)

    def on_postprocess_batch_each(self, p, *args):
        """Advance the iteration counter and, if enabled, move seeds past this iteration."""
        self.batch_index += 1
        if shared.opts.data.get('controlnet_increment_seed_during_batch', False):
            p.seed = p.seed + len(p.all_prompts)
            p.subseed = p.subseed + len(p.all_prompts)

    def on_postprocess_batch(self, p, *args):
        """End a batch run: reset state and restore the seeds captured at the start."""
        self.is_batch = False
        self.batch_index = 0
        self.batch_size = 1
        # Only restore seeds actually captured for this run. init_seed is still
        # None if the option was switched on mid-run or if setup failed before
        # capture; restoring then would write None into ``p``.
        if shared.opts.data.get('controlnet_increment_seed_during_batch', False) and self.init_seed is not None:
            p.seed = self.init_seed
            p.all_seeds = [self.init_seed]
            p.subseed = self.init_subseed
            p.all_subseeds = [self.init_subseed]
        # Forget the captured seeds so a later run can never restore stale ones.
        self.init_seed = None
        self.init_subseed = None

    def dispatch_callbacks(self, callbacks, *args):
        """Call every callable in *callbacks* with *args*, in order."""
        for callback in callbacks:
            callback(*args)
def hijack_function(module, name, new_name, new_value):
    """Replace ``module.<name>`` with *new_value*, stashing the current value.

    The value being replaced is stored as ``module.<new_name>`` so that
    ``unhijack_function`` can put it back. Any earlier hijack under the same
    backup name is undone first, so hijacking twice (e.g. after a script
    reload) never stashes a previous replacement as the "original".
    """
    # restore original function in case of reload
    unhijack_function(module=module, name=name, new_name=new_name)
    original = getattr(module, name)
    setattr(module, new_name, original)
    setattr(module, name, new_value)
def unhijack_function(module, name, new_name):
    """Restore ``module.<name>`` from the backup stored at ``module.<new_name>``.

    The backup attribute is removed afterwards. When no backup exists this is
    a no-op, so it is safe to call repeatedly or before any hijack.
    """
    try:
        saved = getattr(module, new_name)
    except AttributeError:
        # Nothing was hijacked under this backup name.
        return
    setattr(module, name, saved)
    delattr(module, new_name)
class InputMode(Enum):
    """How a ControlNet unit supplies its control image.

    ``SIMPLE`` units contribute their single ``image`` to every batch entry;
    ``BATCH`` units contribute one entry of ``batch_images`` (a list, or a
    directory path expanded via ``shared.listfiles``) per batch iteration --
    see ``get_cn_batches``.
    """
    SIMPLE = "simple"
    BATCH = "batch"
def get_cn_batches(p: processing.StableDiffusionProcessing) -> Tuple[bool, List[List[str]], str, List[str]]:
    """Collect the per-iteration ControlNet inputs for processing *p*.

    Only enabled units are considered. A unit is a batch unit when its
    ``input_mode`` equals ``InputMode.BATCH``; every other unit is treated as a
    simple unit and contributes its ``image`` to each batch entry.

    Returns:
        ``(any_unit_is_batch, batches, output_dir, input_file_names)`` where
        ``batches[i]`` holds one input per enabled unit (in unit order) for
        iteration ``i``. The number of iterations is the length of the shortest
        batch unit's ``batch_images`` (1 when no unit is in batch mode).
        ``output_dir`` and ``input_file_names`` come from the last batch unit
        ('' and [] when there is none).
    """
    units = external_code.get_all_units_in_processing(p)
    # Shallow copies, so expanding ``batch_images`` below does not modify the
    # unit objects returned by get_all_units_in_processing.
    units = [copy(unit) for unit in units if getattr(unit, 'enabled', False)]

    # Evaluate the mode once per unit so that directory expansion, sizing and
    # batch assembly all agree on which units are batch units (previously an
    # unrecognised mode was skipped in the first two but indexed as a batch
    # unit in the last).
    is_batch_unit = [getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH for unit in units]
    batch_units = [unit for unit, flag in zip(units, is_batch_unit) if flag]

    output_dir = ''
    input_file_names = []
    for unit in batch_units:
        output_dir = getattr(unit, 'output_dir', '')
        if isinstance(unit.batch_images, str):
            # A string is a directory path: expand it to its file list.
            unit.batch_images = shared.listfiles(unit.batch_images)
        input_file_names = unit.batch_images

    any_unit_is_batch = bool(batch_units)
    if any_unit_is_batch:
        cn_batch_size = min(len(unit.batch_images) for unit in batch_units)
    else:
        cn_batch_size = 1

    batches = [
        [unit.batch_images[i] if flag else unit.image for unit, flag in zip(units, is_batch_unit)]
        for i in range(cn_batch_size)
    ]
    return any_unit_is_batch, batches, output_dir, input_file_names
# Module-level BatchHijack instance. Constructing it does not patch anything;
# the hooks are installed only when instance.do_hijack() is called.
instance = BatchHijack()