SpawnedShoyo commited on
Commit
86b9624
·
verified ·
1 Parent(s): 67cbce8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -24
app.py CHANGED
@@ -4,33 +4,39 @@ from diffusers import DiffusionPipeline
4
  from PIL import Image, ImageDraw, ImageFont
5
 
6
  # Load the model (make sure to use a model that exists on Hugging Face)
7
- model = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32)
 
8
 
9
  def generate_image(caption):
10
  # Generate the image from the caption
11
- image = model(caption).images[0]
12
-
13
- # Convert to PIL Image for drawing
14
- image = image.convert("RGBA")
15
-
16
- # Create a draw object
17
- draw = ImageDraw.Draw(image)
18
-
19
- # Define font size and color
20
- font_size = 40
21
- font_color = "white"
22
-
23
- # Load a font
24
- font = ImageFont.load_default() # You can specify a TTF font file if needed
25
-
26
- # Calculate text size and position
27
- text_width, text_height = draw.textsize(caption, font=font)
28
- text_position = ((image.width - text_width) // 2, 10) # Centered at the top
29
-
30
- # Draw the text on the image
31
- draw.text(text_position, caption, font=font, fill=font_color)
32
-
33
- return image
 
 
 
 
 
34
 
35
  # Create the Gradio interface
36
  with gr.Blocks() as demo:
 
4
  from PIL import Image, ImageDraw, ImageFont
5
 
6
  # Load the model (make sure to use a model that exists on Hugging Face)
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+ model = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to(device)
9
 
10
  def generate_image(caption):
11
  # Generate the image from the caption
12
+ try:
13
+ with torch.no_grad():
14
+ image = model(caption).images[0]
15
+
16
+ # Convert to PIL Image for drawing
17
+ image = image.convert("RGBA")
18
+
19
+ # Create a draw object
20
+ draw = ImageDraw.Draw(image)
21
+
22
+ # Define font size and color
23
+ font_size = 40
24
+ font_color = "white"
25
+
26
+ # Load a font
27
+ font = ImageFont.load_default() # You can specify a TTF font file if needed
28
+
29
+ # Calculate text size and position
30
+ text_width, text_height = draw.textsize(caption, font=font)
31
+ text_position = ((image.width - text_width) // 2, 10) # Centered at the top
32
+
33
+ # Draw the text on the image
34
+ draw.text(text_position, caption, font=font, fill=font_color)
35
+
36
+ return image
37
+ except Exception as e:
38
+ print(f"Error generating image: {e}")
39
+ return None
40
 
41
  # Create the Gradio interface
42
  with gr.Blocks() as demo: