File size: 3,441 Bytes
1bc457e
 
 
 
 
 
 
7fbdac4
ea5c647
1bc457e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f70725b
 
 
 
1bc457e
7fbdac4
 
1bc457e
f70725b
 
 
 
 
 
 
 
 
 
1bc457e
 
 
 
a3f5c82
1bc457e
 
 
 
 
 
 
 
 
ea5c647
 
 
 
 
 
 
 
1bc457e
ea5c647
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import math
from typing import List, Optional

from PIL import Image

from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline, Img2Img
from internals.util.cache import clear_cuda_and_gc
from internals.util.config import get_base_dimension, get_is_sdxl, get_model_dir


class HighRes(AbstractPipeline):
    """Img2img pass that re-renders images after resizing them to a target size.

    The class wraps an ``Img2Img`` pipeline (see ``load``) and additionally
    offers static helpers for choosing the size of a lower-resolution first
    pass from the final target size.
    """

    def load(self, img2img: Optional[Img2Img] = None) -> None:
        """Attach the img2img pipeline used by ``apply``.

        Args:
            img2img: Pipeline wrapper to reuse. When falsy, a new ``Img2Img``
                is created and loaded from ``get_model_dir()``.

        The call is a no-op once ``self.pipe`` has been set.
        """
        if hasattr(self, "pipe"):
            return

        if not img2img:
            img2img = Img2Img()
            img2img.load(get_model_dir())

        self.pipe = img2img.pipe
        self.img2img = img2img

    def apply(
        self,
        prompt: List[str],
        negative_prompt: List[str],
        images,
        width: int,
        height: int,
        num_inference_steps: int,
        strength: float = 0.5,
        guidance_scale: float = 9,
        **kwargs,
    ):
        """Resize ``images`` to ``(width, height)`` and run them through the pipe.

        Args:
            prompt: Prompts forwarded to the pipe as ``prompt``.
            negative_prompt: Prompts forwarded as ``negative_prompt``.
            images: Iterable of objects exposing ``resize((width, height))``
                (presumably PIL images -- confirm against callers).
            width: Width every input image is resized to.
            height: Height every input image is resized to.
            num_inference_steps: Forwarded to the pipe unchanged.
            strength: Forwarded to the pipe unchanged.
            guidance_scale: Forwarded to the pipe unchanged.
            **kwargs: Extra pipe arguments. They are merged last, so a key
                given here overrides the value built from the named
                parameters above.

        Returns:
            Whatever ``Result.from_result`` builds from the pipe output.
        """
        clear_cuda_and_gc()

        resized_images = [image.resize((width, height)) for image in images]
        # Use a separate name instead of rebinding the ``kwargs`` parameter so
        # the caller-supplied extras stay distinguishable from the final call.
        pipe_kwargs = {
            "prompt": prompt,
            "image": resized_images,
            "strength": strength,
            "negative_prompt": negative_prompt,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            **kwargs,
        }
        result = self.pipe(**pipe_kwargs)
        return Result.from_result(result)

    @staticmethod
    def get_intermediate_dimension(target_width: int, target_height: int):
        """Return the ``(width, height)`` to use for the first pass.

        The target size is scaled, keeping its aspect ratio, so that its pixel
        count matches ``get_base_dimension() ** 2``; each side is then rounded
        up to a multiple of 64. When ``get_is_sdxl()`` is truthy the result is
        replaced by the closest entry of ``SD_XL_BASE_RATIOS``.

        Args:
            target_width: Final image width in pixels; must be positive.
            target_height: Final image height in pixels; must be positive.

        Raises:
            ValueError: If either dimension is zero or negative.
        """
        # Zero would divide by zero below, and negative sizes would either hit
        # a math domain error or silently produce negative dimensions.
        if target_width <= 0 or target_height <= 0:
            raise ValueError(
                "target dimensions must be positive, "
                f"got {target_width}x{target_height}"
            )

        def_size = get_base_dimension()

        desired_pixel_count = def_size * def_size
        actual_pixel_count = target_width * target_height

        scale = math.sqrt(desired_pixel_count / actual_pixel_count)

        firstpass_width = math.ceil(scale * target_width / 64) * 64
        firstpass_height = math.ceil(scale * target_height / 64) * 64

        print("Pass1", firstpass_width, firstpass_height)

        if get_is_sdxl():
            firstpass_width, firstpass_height = HighRes.find_closest_sdxl_aspect_ratio(
                firstpass_width, firstpass_height
            )

        print("Pass2", firstpass_width, firstpass_height)
        return firstpass_width, firstpass_height

    @staticmethod
    def find_closest_sdxl_aspect_ratio(target_width: int, target_height: int):
        """Return the ``SD_XL_BASE_RATIOS`` size closest in aspect ratio.

        Closeness is the absolute difference between ``target_width /
        target_height`` and each entry's ``width / height``. On a tie the
        entry that appears first in the table wins.

        Raises:
            ZeroDivisionError: If ``target_height`` is zero.
        """
        target_ratio = target_width / target_height

        # ``min`` returns the first minimal item, matching a strict ``<`` scan.
        new_width, new_height = min(
            SD_XL_BASE_RATIOS.values(),
            key=lambda size: abs(target_ratio - size[0] / size[1]),
        )
        return new_width, new_height


# Size table consulted by HighRes.find_closest_sdxl_aspect_ratio.
# Each key is the width / height ratio rounded to two decimals, stored as a
# string; each value is the matching (width, height) pair. Every side is a
# multiple of 64. Entries are listed from the narrowest ratio to the widest,
# and that order is kept because the lookup prefers the earliest entry on ties.
SD_XL_BASE_RATIOS = dict(
    [
        ("0.5", (704, 1408)),
        ("0.52", (704, 1344)),
        ("0.57", (768, 1344)),
        ("0.6", (768, 1280)),
        ("0.68", (832, 1216)),
        ("0.72", (832, 1152)),
        ("0.78", (896, 1152)),
        ("0.82", (896, 1088)),
        ("0.88", (960, 1088)),
        ("0.94", (960, 1024)),
        ("1.0", (1024, 1024)),
        ("1.07", (1024, 960)),
        ("1.13", (1088, 960)),
        ("1.21", (1088, 896)),
        ("1.29", (1152, 896)),
        ("1.38", (1152, 832)),
        ("1.46", (1216, 832)),
        ("1.67", (1280, 768)),
        ("1.75", (1344, 768)),
        ("1.91", (1344, 704)),
        ("2.0", (1408, 704)),
        ("2.09", (1472, 704)),
        ("2.4", (1536, 640)),
        ("2.5", (1600, 640)),
        ("2.89", (1664, 576)),
        ("3.0", (1728, 576)),
    ]
)