File size: 6,689 Bytes
a0efccd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import glob
import importlib
import io
import os
import sys

from PIL import Image
from scepter.modules.transform.io import pillow_convert
from scepter.modules.utils.config import Config
from scepter.modules.utils.file_system import FS

# When the current working directory contains an ``__init__.py`` (the path is
# relative to the CWD, so this presumably targets running from the repository
# root), execute it as a package named ``scepter_ext``. This happens before the
# ``examples`` / ``inference`` imports below, so whatever that file sets up is
# already in place when they run.
if os.path.exists('__init__.py'):
    # ``import importlib`` alone does not guarantee that the ``importlib.util``
    # submodule has been loaded; import it explicitly before using it.
    import importlib.util
    package_name = 'scepter_ext'
    spec = importlib.util.spec_from_file_location(package_name, '__init__.py')
    package = importlib.util.module_from_spec(spec)
    # Register in sys.modules before executing, following the standard
    # importlib recipe, so the package can be resolved by name while it runs.
    sys.modules[package_name] = package
    spec.loader.exec_module(package)

from examples.examples import fft_examples as all_examples
from inference.registry import INFERENCES
# File-system backends initialised for this script, in registration order.
# Every backend shares the same TEMP_DIR of "./cache".
fs_list = [
    Config(cfg_dict={"NAME": fs_name, "TEMP_DIR": "./cache"}, load=False)
    for fs_name in ("HuggingfaceFs", "ModelscopeFs", "HttpFs", "LocalFs")
]

for one_fs in fs_list:
    FS.init_fs_client(one_fs)


def run_one_case(pipe,
                input_image = None,
                input_mask = None,
                input_reference_image = None,
                save_path = "examples/output/example.png",
                instruction = "",
                output_h = 1024,
                output_w = 1024,
                seed = -1,
                sample_steps = None,
                guide_scale = None,
                repainting_scale = None,
                use_change=True,
                keep_pixels=True,
                keep_pixels_rate=0.8,
                **kwargs):
    if input_image is not None:
        input_image = Image.open(io.BytesIO(FS.get_object(input_image)))
        input_image = pillow_convert(input_image, "RGB")
    if input_mask is not None:
        input_mask = Image.open(io.BytesIO(FS.get_object(input_mask)))
        input_mask = pillow_convert(input_mask, "L")
    if input_reference_image is not None:
        input_reference_image = Image.open(io.BytesIO(FS.get_object(input_reference_image)))
        input_reference_image = pillow_convert(input_reference_image, "RGB")
    print(repainting_scale)
    image, _, _, _, seed = pipe(
        reference_image=input_reference_image,
        edit_image=input_image,
        edit_mask=input_mask,
        prompt=instruction,
        output_height=output_h,
        output_width=output_w,
        sampler='flow_euler',
        sample_steps=sample_steps or pipe.input.get("sample_steps", 28),
        guide_scale=guide_scale or pipe.input.get("guide_scale", 50),
        seed=seed,
        repainting_scale=repainting_scale,
        use_change=use_change,
        keep_pixels=keep_pixels,
        keep_pixels_rate=keep_pixels_rate
    )
    with FS.put_to(save_path) as local_path:
        image.save(local_path)
    return local_path, seed


def _build_arg_parser():
    """Create the command-line parser for this inference script."""
    parser = argparse.ArgumentParser(description='Argparser for Scepter:\n')
    parser.add_argument('--instruction',
                        dest='instruction',
                        help='The instruction for editing or generating!',
                        default="")
    parser.add_argument('--output_h',
                        dest='output_h',
                        help='The height of output image for generation tasks!',
                        type=int,
                        default=1024)
    parser.add_argument('--output_w',
                        dest='output_w',
                        help='The width of output image for generation tasks!',
                        type=int,
                        default=1024)
    parser.add_argument('--input_reference_image',
                        dest='input_reference_image',
                        help='The input reference image!',
                        default=None
                        )
    parser.add_argument('--input_image',
                        dest='input_image',
                        help='The input image!',
                        default=None
                        )
    parser.add_argument('--input_mask',
                        dest='input_mask',
                        help='The input mask!',
                        default=None
                        )
    parser.add_argument('--save_path',
                        dest='save_path',
                        help='The save path for output image!',
                        default='examples/output_images/output.png'
                        )
    parser.add_argument('--seed',
                        dest='seed',
                        help='The seed for generation!',
                        type=int,
                        default=-1)

    parser.add_argument('--step',
                        dest='step',
                        help='The sample step for generation!',
                        type=int,
                        default=None)

    # Scales are parsed as floats so fractional values (e.g. 0.5 or 3.5) are
    # accepted; integer strings still parse as before.
    parser.add_argument('--guide_scale',
                        dest='guide_scale',
                        help='The guide scale for generation!',
                        type=float,
                        default=None)

    parser.add_argument('--repainting_scale',
                        dest='repainting_scale',
                        help='The repainting scale for content filling generation!',
                        type=float,
                        default=None)
    return parser


def run():
    """Command-line entry point.

    Builds the pipeline from ``config/ace_plus_fft.yaml``. When no instruction,
    input image or reference image is given, it runs every bundled example.
    Otherwise it runs the single case described by the command-line arguments.
    Each result's local path and seed are printed.
    """
    cfg = Config(load=True, parser_ins=_build_arg_parser())
    model_cfg = Config(load=True, cfg_file="config/ace_plus_fft.yaml")
    pipe = INFERENCES.build(model_cfg)

    args = cfg.args
    if args.instruction == "" and args.input_image is None and args.input_reference_image is None:
        # No task on the command line: run the bundled examples.
        overrides = {
            "output_h": args.output_h,
            "output_w": args.output_w,
            "sample_steps": args.step,
            "guide_scale": args.guide_scale
        }
        # Only override settings the user actually supplied, so an example's
        # own values are not replaced by None.
        overrides = {key: value for key, value in overrides.items() if value is not None}
        for example in all_examples:
            # Merge into a fresh dict so the shared example dicts are not mutated.
            local_path, seed = run_one_case(pipe, **{**example, **overrides})
            print(local_path, seed)
    else:
        params = {
            "input_image": args.input_image,
            "input_mask": args.input_mask,
            "input_reference_image": args.input_reference_image,
            "save_path": args.save_path,
            "instruction": args.instruction,
            "output_h": args.output_h,
            "output_w": args.output_w,
            "seed": args.seed,  # previously parsed but never forwarded
            "sample_steps": args.step,
            "guide_scale": args.guide_scale,
            "repainting_scale": args.repainting_scale,
        }
        local_path, seed = run_one_case(pipe, **params)
        print(local_path, seed)

if __name__ == "__main__":
    # Launch the CLI only when executed as a script, never on import.
    run()