File size: 2,320 Bytes
f466dd9
a9b8939
9e09422
cd715cb
e0ec116
b5ad13a
a9b8939
 
6d1d03a
e0ec116
f466dd9
ccee0a8
1d0b035
e0ec116
6449f8f
 
 
c301a62
 
 
 
 
6449f8f
c301a62
f466dd9
b5ad13a
8a0f059
f466dd9
c0cd59a
6035350
1d0b035
e0ec116
 
 
 
 
6449f8f
6035350
e0ec116
b5ad13a
 
6035350
e0ec116
6035350
e0ec116
 
 
6035350
e0ec116
6035350
b5ad13a
6035350
f466dd9
 
 
d05fa5e
1adc78a
 
e0ec116
d05fa5e
6035350
f466dd9
 
 
a9b8939
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gradio as gr
from diffusers import AutoPipelineForText2Image
from io import BytesIO
from generate_propmts import generate_prompt
from concurrent.futures import ThreadPoolExecutor, as_completed

# Load the SDXL-Turbo pipeline once at import time so all requests share it;
# loading per-call would re-read multi-GB weights on every request.
# NOTE(review): no .to("cuda") here — runs on CPU unless the host moves it; confirm intended.
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")

def generate_image(prompt):
    """Render a single image for *prompt* with the shared SDXL-Turbo pipeline.

    Args:
        prompt: Text prompt passed straight to the pipeline.

    Returns:
        JPEG-encoded image bytes on success, or ``None`` on any failure
        (callers filter out ``None`` results).
    """
    # Keep the try body minimal: only the model call can legitimately blow up
    # (OOM, tokenizer errors, etc.). This is a worker-thread boundary, so a
    # broad catch-and-log is deliberate — one bad prompt must not kill the batch.
    try:
        # num_inference_steps=1 / guidance_scale=0.0 is the documented
        # SDXL-Turbo fast path (turbo models are distilled for 1-step, no CFG).
        output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
    except Exception as e:
        print(f"Error generating image: {e}")
        return None

    print(f"Model output: {output}")

    # Previously this raised a generic Exception just to jump to its own
    # except block; an explicit branch is clearer and keeps the log distinct.
    images = getattr(output, "images", None)
    if not images:
        print("No images returned by the model.")
        return None

    # Encode the PIL image to JPEG bytes for transport to the Gradio gallery.
    buffered = BytesIO()
    images[0].save(buffered, format="JPEG")
    return buffered.getvalue()

def inference(sentence_mapping, character_dict, selected_style):
    """Generate one image per paragraph, preserving paragraph order.

    Args:
        sentence_mapping: Mapping of paragraph number -> list of sentences.
        character_dict: Character descriptions forwarded to the prompt builder.
        selected_style: Art style string (e.g. "oil painting").

    Returns:
        List of JPEG byte strings in paragraph order; failed generations
        are silently dropped after being logged.
    """
    print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')

    # Build one prompt per paragraph. NOTE(review): dict insertion order is
    # relied on for paragraph order — confirm callers send paragraphs in order.
    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
        prompts.append(prompt)
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    # Fan out to threads but keep results indexed by prompt position.
    # BUG FIX: the old code appended results in as_completed() order, so the
    # gallery showed paragraphs shuffled by whichever generation finished first.
    results = [None] * len(prompts)
    with ThreadPoolExecutor() as executor:
        future_to_index = {
            executor.submit(generate_image, prompt): index
            for index, prompt in enumerate(prompts)
        }
        for future in as_completed(future_to_index):
            index = future_to_index[future]
            try:
                results[index] = future.result()
            except Exception as e:
                print(f"Error processing prompt: {e}")

    # Drop failures (None) while keeping the surviving images in order.
    return [image for image in results if image]

# Declarative Gradio UI: two JSON inputs (paragraph->sentences mapping and
# character descriptions) plus a style dropdown, rendered as an image gallery.
# inference() returns a list of JPEG byte strings; gr.Gallery accepts a list
# of images — NOTE(review): confirm the installed Gradio version renders raw
# bytes (some versions expect PIL images or file paths).
gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs=gr.Gallery(label="Generated Images")
)

# Launch only when run as a script, not when imported (e.g. by a test or a
# hosting wrapper that calls .launch() itself).
if __name__ == "__main__":
    gradio_interface.launch()