Spaces:
Running
Running
import os | |
import csv | |
import torch | |
from diffusers import AutoPipelineForText2Image | |
def load_prompts(path):
    """Load the text prompts from a supported prompt file.

    The only supported file is one whose base name is ``ViLG-300.csv``. It is
    parsed as a comma-separated file with a header row, and the values of its
    ``Prompt`` column are returned. Other columns are not read.

    Args:
        path: Path to the prompt file.

    Returns:
        A list of the unique prompt strings, in the order they first appear
        in the file.

    Raises:
        NotImplementedError: If the base name of ``path`` is not a supported
            prompt file.
        KeyError: If the CSV has no ``Prompt`` column.
    """
    if os.path.basename(path) != 'ViLG-300.csv':
        # Raise (rather than return) so an unsupported file fails here with a
        # clear message instead of later, when the caller iterates the result.
        raise NotImplementedError(f'Unsupported prompt file: {path}')

    # 'utf-8-sig' strips a leading byte-order mark when one is present, so the
    # first header is always plain 'Prompt' whether or not the file has a BOM.
    # newline='' is the csv module's recommended way to open input files.
    with open(path, 'r', encoding='utf-8-sig', newline='') as csv_file:
        reader = csv.DictReader(csv_file, delimiter=',')
        # dict.fromkeys drops repeated prompts while keeping first-seen order.
        return list(dict.fromkeys(row['Prompt'] for row in reader))
def main(
    model_id="runwayml/stable-diffusion-v1-5",
    prompt_path="assets/ViLG-300.csv",
    save_path=None,
    dtype='fp16',
    variant=None,
):
    """Generate one image per prompt and save the images to disk.

    Args:
        model_id: Model identifier passed to
            ``AutoPipelineForText2Image.from_pretrained``.
        prompt_path: Path to the prompt file handed to ``load_prompts``.
        save_path: Output directory. When ``None``, ``saved/<model_id>`` is
            used, with every ``/`` in the model id replaced by ``_``.
        dtype: ``'fp32'`` selects ``torch.float32``; any other value selects
            ``torch.float16``.
        variant: Forwarded unchanged as the ``variant`` argument of
            ``from_pretrained``.
    """
    out_dir = save_path
    if out_dir is None:
        out_dir = os.path.join('saved', model_id.replace('/', '_'))
    os.makedirs(out_dir, exist_ok=True)

    prompts = load_prompts(prompt_path)

    if dtype == 'fp32':
        torch_dtype = torch.float32
    else:
        torch_dtype = torch.float16

    pipeline = AutoPipelineForText2Image.from_pretrained(
        model_id,
        variant=variant,
        torch_dtype=torch_dtype,
    )
    pipeline.to(device='cuda')
    # Presumably disables the pipeline's safety checker so no output is
    # filtered -- confirm against the diffusers pipeline in use.
    pipeline.safety_checker = None

    for index, prompt in enumerate(prompts):
        print(f'{index}|{len(prompts)}: {prompt}')
        result = pipeline(prompt)
        first_image = result.images[0]
        # Files are named by the prompt's position in the prompt list.
        target = os.path.join(out_dir, f'{index}.jpg')
        first_image.save(target)
if __name__ == '__main__':
    # Imported here so `fire` is only required when run as a script.
    from fire import Fire

    # Hand `main` to Fire, which runs it as the command-line entry point.
    Fire(main)