File size: 1,664 Bytes
00b7308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
728f93e
021120f
00b7308
 
 
 
 
 
728f93e
 
021120f
728f93e
 
 
 
00b7308
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import csv

import torch
from diffusers import AutoPipelineForText2Image


def load_prompts(path):
    """Load the text prompts to render from a prompt file.

    Only the ViLG-300 benchmark CSV is supported. It is recognised by its file
    name (``ViLG-300.csv``), not by its contents. Only the ``Prompt`` column is
    used; the other columns (文本 / 类别 / 来源 — text, category, source) are
    ignored.

    Args:
        path: Path to the prompt file.

    Returns:
        A list of the unique ``Prompt`` values, in the order they first appear
        in the file.

    Raises:
        NotImplementedError: If ``path`` is not named ``ViLG-300.csv``.
        KeyError: If the CSV has no ``Prompt`` column.
    """
    if os.path.basename(path) != 'ViLG-300.csv':
        # Raise rather than return the exception class: returning it made the
        # caller try to iterate over the class and fail with an unrelated error.
        raise NotImplementedError(f'Unsupported prompt file: {path!r}')

    # 'utf-8-sig' strips a leading byte-order mark if present, so the first
    # header is read as 'Prompt' whether or not the file was saved with a BOM.
    # newline='' is what the csv module expects for correct line handling.
    with open(path, 'r', encoding='utf-8-sig', newline='') as csv_file:
        prompts = [row['Prompt'] for row in csv.DictReader(csv_file, delimiter=',')]

    # dict.fromkeys de-duplicates while keeping first-seen order, matching the
    # previous behaviour of keying a dict by prompt.
    return list(dict.fromkeys(prompts))


def main(
    model_id="runwayml/stable-diffusion-v1-5",
    prompt_path="assets/ViLG-300.csv",
    save_path=None,
    dtype='fp16',
    variant=None,
    device='cuda',
):
    """Render one image per prompt with a diffusers text-to-image pipeline.

    Images are written as ``<save_path>/<i>.jpg``, where ``i`` is the prompt's
    position in the list returned by ``load_prompts``.

    Args:
        model_id: Model id or path passed to
            ``AutoPipelineForText2Image.from_pretrained``.
        prompt_path: Prompt file understood by ``load_prompts``.
        save_path: Output directory. Defaults to ``saved/<model_id>`` with
            ``/`` replaced by ``_``. It is created if it does not exist.
        dtype: Weight precision: ``'fp32'``, ``'fp16'`` or ``'bf16'``.
        variant: Optional weights variant forwarded to ``from_pretrained``.
        device: Device the pipeline is moved to (default ``'cuda'``).

    Raises:
        ValueError: If ``dtype`` is not one of the supported names.
    """
    torch_dtypes = {
        'fp32': torch.float32,
        'fp16': torch.float16,
        'bf16': torch.bfloat16,
    }
    # Fail fast on a typo instead of silently falling back to fp16.
    if dtype not in torch_dtypes:
        raise ValueError(
            f'Unsupported dtype {dtype!r}; expected one of {sorted(torch_dtypes)}'
        )

    if save_path is None:
        save_path = os.path.join('saved', model_id.replace('/', '_'))
    # Create the directory for user-supplied paths too; previously only the
    # default path was created, so a missing custom directory broke image.save.
    os.makedirs(save_path, exist_ok=True)

    prompts = load_prompts(prompt_path)
    pipeline = AutoPipelineForText2Image.from_pretrained(
        model_id,
        variant=variant,
        torch_dtype=torch_dtypes[dtype],
    )
    pipeline.to(device=device)
    # NOTE(review): assumes the pipeline skips NSFW filtering when this is None.
    pipeline.safety_checker = None

    total = len(prompts)
    for i, prompt in enumerate(prompts):
        print(f'{i}|{total}: {prompt}')
        image = pipeline(prompt).images[0]
        image.save(os.path.join(save_path, f'{i}.jpg'))


if __name__ == '__main__':
    # Command-line entry point: python-fire turns main()'s keyword arguments
    # into CLI flags (e.g. --model_id=..., --dtype=fp32).
    from fire import Fire

    Fire(main)