File size: 1,919 Bytes
fb4fac3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import numpy as np
from .processors import Processor_id


class ControlNetConfigUnit:
    """Declarative description of one ControlNet unit.

    The constructor only records its arguments; nothing is loaded or
    validated here.

    Attributes:
        processor_id: Identifier of the processor to use (a ``Processor_id``).
        model_path: Location of the model, stored as given.
            NOTE(review): presumably a filesystem path to ControlNet
            weights -- confirm against the code that consumes this config.
        scale: Multiplier associated with this unit; defaults to ``1.0``.
    """

    def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
        # Plain value holder: keep the three settings exactly as passed in.
        self.processor_id, self.model_path, self.scale = (
            processor_id,
            model_path,
            scale,
        )


class ControlNetUnit:
    """One ready-to-use ControlNet unit: a processor, a model and a scale.

    The constructor only records its arguments; no call is made on the
    processor or the model.

    Attributes:
        processor: Object stored as the unit's processor.
        model: Object stored as the unit's model.
        scale: Multiplier associated with this unit; defaults to ``1.0``.
    """

    def __init__(self, processor, model, scale=1.0):
        # Plain value holder: keep the three components exactly as passed in.
        self.processor, self.model, self.scale = processor, model, scale


class MultiControlNetManager:
    """Drive several ControlNet units together.

    The manager keeps three parallel lists -- ``processors``, ``models`` and
    ``scales`` -- taken from the units it was built with.  It can preprocess
    an image with the processors and, when called, run every model and sum
    their scaled outputs.
    """

    def __init__(self, controlnet_units=None):
        """Collect the processor, model and scale of every unit.

        Args:
            controlnet_units: Iterable of objects exposing ``processor``,
                ``model`` and ``scale`` attributes.  ``None`` (the default)
                means "no units".
        """
        # ``None`` replaces a mutable ``[]`` default, and the units are
        # materialized once so that a generator (or any other single-pass
        # iterable) feeds all three lists instead of only the first one.
        units = [] if controlnet_units is None else list(controlnet_units)
        self.processors = [unit.processor for unit in units]
        self.models = [unit.model for unit in units]
        self.scales = [unit.scale for unit in units]

    @staticmethod
    def _image_to_tensor(image):
        """Convert one H x W x C image into a ``1 x C x H x W`` float tensor."""
        # Dividing by 255 rescales the values; this assumes the image holds
        # 8-bit channel data (0-255) -- confirm for non-uint8 inputs.
        array = np.array(image, dtype=np.float32) / 255
        # The permute moves the last (channel) axis first, so the image must
        # be three-dimensional; unsqueeze adds the leading batch axis.
        return torch.Tensor(array).permute(2, 0, 1).unsqueeze(0)

    def process_image(self, image, processor_id=None):
        """Run the processors on ``image`` and stack the results.

        Args:
            image: Input handed unchanged to each selected processor.
            processor_id: Index into ``self.processors`` selecting a single
                processor, or ``None`` to run all of them in order.

        Returns:
            A tensor with one entry along dim 0 per processor that was run;
            each entry is the processor's output converted by
            ``_image_to_tensor``.  All outputs must therefore share a shape.

        Raises:
            IndexError: If ``processor_id`` is out of range.
        """
        if processor_id is None:
            selected = self.processors
        else:
            selected = [self.processors[processor_id]]
        # Run every processor first, then convert, so a failing processor is
        # reported before any tensor conversion work is done.
        processed_images = [processor(image) for processor in selected]
        return torch.concat(
            [self._image_to_tensor(processed) for processed in processed_images],
            dim=0,
        )

    def __call__(
        self,
        sample, timestep, encoder_hidden_states, conditionings,
        tiled=False, tile_size=64, tile_stride=32
    ):
        """Run every model on its conditioning and sum the scaled outputs.

        ``conditionings``, ``self.models`` and ``self.scales`` are walked in
        lockstep with ``zip``, so iteration stops at the shortest of them.

        Args:
            sample: Forwarded to each model as the first positional argument.
            timestep: Forwarded to each model unchanged.
            encoder_hidden_states: Forwarded to each model unchanged.
            conditionings: Iterable with one conditioning input per model.
            tiled: Forwarded to each model as the ``tiled`` keyword.
            tile_size: Forwarded to each model as the ``tile_size`` keyword.
            tile_stride: Forwarded to each model as the ``tile_stride``
                keyword.

        Returns:
            A list holding the element-wise sum, across models, of each
            model's outputs multiplied by that model's scale; ``None`` when
            no model was run (no units or no conditionings).
        """
        res_stack = None
        for conditioning, model, scale in zip(conditionings, self.models, self.scales):
            scaled = [
                res * scale
                for res in model(
                    sample, timestep, encoder_hidden_states, conditioning,
                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
                )
            ]
            if res_stack is None:
                # The first model's scaled outputs seed the accumulator.
                res_stack = scaled
            else:
                res_stack = [total + res for total, res in zip(res_stack, scaled)]
        return res_stack