Spaces:
Sleeping
Sleeping
File size: 2,208 Bytes
e47c7c5 3e248d4 e47c7c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import gradio as gr
import torch
from syngen_diffusion_pipeline import SynGenDiffusionPipeline
import subprocess
def install_spacy_model(model_name):
    """Install spaCy and download a spaCy model into the running interpreter.

    Both steps run ``<interpreter> -m pip ...`` / ``<interpreter> -m spacy ...``
    with ``sys.executable`` so the packages are installed into the same
    environment that is executing this script; a bare ``"python"`` resolved
    from PATH may be a different interpreter, leaving the model unimportable.

    This is best-effort: a failing command is reported on stderr but not
    raised, so the app can continue starting up.

    Args:
        model_name: spaCy model package name, e.g. ``"en_core_web_trf"``.

    Returns:
        True if both commands exited successfully, False otherwise.
    """
    import sys

    # sys.executable can be empty/None in some embedded interpreters.
    python = sys.executable or "python"
    try:
        subprocess.check_call([python, "-m", "pip", "install", "spacy"])
        subprocess.check_call([python, "-m", "spacy", "download", model_name])
    except subprocess.CalledProcessError as e:
        print(f"Error occurred while installing the model: {model_name}", file=sys.stderr)
        print(f"Error details: {str(e)}", file=sys.stderr)
        return False
    return True
# Fetch the spaCy "en_core_web_trf" model before the pipeline is built
# (presumably required by SynGenDiffusionPipeline to parse prompts — confirm).
install_spacy_model("en_core_web_trf")

# Stable Diffusion v1.4 weights, placed on the first GPU when CUDA is available.
model_path = "CompVis/stable-diffusion-v1-4"
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if _use_cuda else "cpu")
pipe = SynGenDiffusionPipeline.from_pretrained(model_path)
pipe = pipe.to(device)
def generate_fn(prompt, seed, num_inference_steps=50):
    """Generate one image for ``prompt`` using the module-level ``pipe``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed from the UI (a string such as ``"42"`` from the Seed
            textbox) or an int. Surrounding whitespace is ignored; a blank
            value or ``None`` falls back to 42, the seed shown as the
            textbox placeholder.
        num_inference_steps: Number of denoising steps (default 50).

    Returns:
        The first element of the pipeline result's ``'images'`` entry.

    Raises:
        ValueError: If ``seed`` is non-blank and not an integer.
    """
    seed_text = "" if seed is None else str(seed).strip()
    if not seed_text:
        # An empty Seed textbox yields ""; int("") would raise an opaque error.
        seed_text = "42"
    try:
        seed_value = int(seed_text)
    except ValueError as err:
        raise ValueError(f"Seed must be an integer, got {seed!r}") from err

    # A fresh generator per request, on the same device type as the pipeline.
    generator = torch.Generator(device.type).manual_seed(seed_value)
    result = pipe(prompt=prompt, generator=generator, num_inference_steps=num_inference_steps)
    return result['images'][0]
title = "SynGen"
description = """
This is the demo for [SynGen](https://github.com/RoyiRa/Syntax-Guided-Generation), an image synthesis approach which first syntactically analyses the prompt to identify entities and their modifiers, and then uses a novel loss function that encourages the cross-attention maps to agree with the linguistic binding reflected by the syntax. Preprint: \"Linguistic Binding in Diffusion Models: Enhancing Attribute Correspondence through Attention Map Alignment\" (arxiv link coming soon).
"""
# Demo prompts, each paired with two seeds. Seeds are strings because they
# populate the Seed Textbox; each entry is [prompt, seed], matching the
# Interface's two inputs.
_EXAMPLE_SEEDS = {
    "a yellow flamingo and a pink sunflower": ("16", "60"),
    "a checkered bowl in a cluttered room": ("69", "77"),
    "a horned lion and a spotted monkey": ("1269", "9146"),
}
examples = [
    [example_prompt, example_seed]
    for example_prompt, seed_options in _EXAMPLE_SEEDS.items()
    for example_seed in seed_options
]
# Gradio UI: a prompt box and a seed box feed generate_fn, and the generated
# image is shown in a single Image output.
prompt_textbox = gr.Textbox(label="Prompt", placeholder="A yellow flamingo and a pink sunflower", lines=1)
seed_textbox = gr.Textbox(label="Seed", placeholder="42", lines=1)
output = gr.Image(label="generation")
# allow_flagging expects "never" / "auto" / "manual"; a bool is only accepted
# via a deprecation shim (which warns and maps False to "never").
demo = gr.Interface(
    fn=generate_fn,
    inputs=[prompt_textbox, seed_textbox],
    outputs=output,
    examples=examples,
    title=title,
    description=description,
    allow_flagging="never",
)
demo.launch()
|