dftrztxi / app.py
import os

# Set the Hugging Face cache location before importing diffusers so the
# Hugging Face libraries pick it up when their modules are first loaded.
os.environ['HF_HOME'] = '/blabla/cache/'

import gradio as gr
from diffusers import DiffusionPipeline
from dask import delayed, compute

# Load the model once at startup; on a CPU-only host it runs in float32.
pipe = DiffusionPipeline.from_pretrained("prompthero/openjourney-v4")
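
# Optional (not part of the app's logic above): Stable Diffusion pipelines in
# diffusers expose attention slicing, which trades a little speed for a lower
# peak memory footprint. A minimal sketch, assuming the loaded pipeline
# supports it:
#
#     pipe.enable_attention_slicing()
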
def generate_image(prompt, num_inference_steps=50):
    """
    Generate an image from a text prompt using the diffusion pipeline.
    Lowering num_inference_steps trades image quality for faster generation.
    """
    image = pipe(prompt, num_inference_steps=num_inference_steps).images[0]
    return image
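
# Hypothetical direct call, bypassing the Gradio UI, e.g. for a quick local
# test (prompt and filename are illustrative only):
#
#     img = generate_image("a watercolor fox", num_inference_steps=25)
#     img.save("fox.png")
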
@delayed
def dask_generate(prompt):
    # Wrapping the call with dask.delayed builds a lazy task; nothing runs
    # until dask.compute() is called on the resulting object.
    return generate_image(prompt)
def parallel_generate(prompt):
    # Build several delayed generation tasks and execute them with Dask's
    # threaded scheduler, so the CPU-bound work runs in parallel threads.
    tasks = [dask_generate(prompt) for _ in range(4)]  # Example with 4 parallel tasks
    images = compute(*tasks, scheduler="threads", num_workers=4)
    return images[0]  # Return the first image generated for simplicity
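
# A minimal alternative sketch (not used above): the same fan-out with only the
# standard-library ThreadPoolExecutor, calling generate_image directly. PyTorch
# releases the GIL during heavy tensor ops, so threads can overlap CPU work.
#
#     from concurrent.futures import ThreadPoolExecutor
#     with ThreadPoolExecutor(max_workers=4) as executor:
#         futures = [executor.submit(generate_image, prompt) for _ in range(4)]
#         images = [f.result() for f in futures]
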
# Gradio interface
iface = gr.Interface(
    fn=parallel_generate,
    inputs=gr.Textbox(label="Prompt", placeholder="Enter your prompt here"),
    outputs=gr.Image(type="pil"),
    title="Multithreaded CPU Optimized Image Generation",
    description="Enter a prompt to generate an image efficiently using CPU optimization and multithreading.",
)
# Launch the Gradio app
if __name__ == "__main__":
    iface.launch()
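
# Usage note: running `python app.py` starts the Gradio server, which by
# default listens on http://127.0.0.1:7860. Passing share=True to launch()
# would additionally create a temporary public link.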