AMfeta99 commited on
Commit
5ace3a2
·
verified ·
1 Parent(s): 6305440

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -1
app.py CHANGED
@@ -4,6 +4,8 @@ import gradio as gr
4
  from smolagents import CodeAgent, InferenceClientModel
5
  from smolagents import DuckDuckGoSearchTool, Tool
6
  from huggingface_hub import InferenceClient
 
 
7
 
8
  # =========================================================
9
  # Utility functions
@@ -76,7 +78,7 @@ class WrappedTextToImageTool(Tool):
76
  '''
77
  from huggingface_hub import InferenceClient
78
 
79
-
80
  class TextToImageTool(Tool):
81
  description = "This tool creates an image according to a prompt, which is a text description."
82
  name = "image_generator"
@@ -88,6 +90,32 @@ class TextToImageTool(Tool):
88
 
89
  def forward(self, prompt):
90
  return self.client.text_to_image(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
 
93
  # =========================================================
 
4
  from smolagents import CodeAgent, InferenceClientModel
5
  from smolagents import DuckDuckGoSearchTool, Tool
6
  from huggingface_hub import InferenceClient
7
+ from diffusers import DiffusionPipeline
8
+ import torch
9
 
10
  # =========================================================
11
  # Utility functions
 
78
  '''
79
  from huggingface_hub import InferenceClient
80
 
81
+ '''
82
  class TextToImageTool(Tool):
83
  description = "This tool creates an image according to a prompt, which is a text description."
84
  name = "image_generator"
 
90
 
91
  def forward(self, prompt):
92
  return self.client.text_to_image(prompt)
93
+ '''
94
+
95
+ class TextToImageTool(Tool):
96
+ description = "This tool creates an image according to a prompt. Add details like 'high-res, photorealistic'."
97
+ name = "image_generator"
98
+ inputs = {
99
+ "prompt": {
100
+ "type": "string",
101
+ "description": "The image generation prompt"
102
+ }
103
+ }
104
+ output_type = "image"
105
+
106
+ def __init__(self):
107
+ super().__init__()
108
+ dtype = torch.bfloat16
109
+ device = "cuda" if torch.cuda.is_available() else "cpu"
110
+ print(f"Using device: {device}")
111
+ self.pipe = DiffusionPipeline.from_pretrained(
112
+ "black-forest-labs/FLUX.1-schnell",
113
+ torch_dtype=dtype
114
+ ).to(device)
115
+
116
+ def forward(self, prompt):
117
+ image = self.pipe(prompt).images[0]
118
+ return image
119
 
120
 
121
  # =========================================================