vilarin committed on
Commit b206729
1 Parent(s): 9aa2c8b

Update app.py

Files changed (1):
  1. app.py +38 -78
app.py CHANGED
@@ -1,12 +1,13 @@
 import spaces
 import os
 import gradio as gr
-import torch
+#import torch
 import numpy as np
 import random
-from diffusers import FluxPipeline
+#from diffusers import FluxPipeline
+from huggingface_hub import AsyncInferenceClient
 from translatepy import Translator
-from huggingface_hub import hf_hub_download
+#from huggingface_hub import hf_hub_download
 import requests
 import re
 
@@ -14,7 +15,7 @@ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 translator = Translator()
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 # Constants
-model = "black-forest-labs/FLUX.1-dev"
+basemodel = "black-forest-labs/FLUX.1-dev"
 MAX_SEED = np.iinfo(np.int32).max
 
 CSS = """
@@ -30,54 +31,25 @@ JS = """function () {
 }
 }"""
 
-if torch.cuda.is_available():
-    pipe = FluxPipeline.from_pretrained(model, torch_dtype=torch.bfloat16)
+client = AsyncInferenceClient()
 
-def scrape_lora_link(url):
-    try:
-        # Send a GET request to the URL
-        response = requests.get(url)
-        response.raise_for_status() # Raise an exception for bad status codes
 
-        # Get the content of the page
-        content = response.text
-
-        # Use regular expression to find the link
-        pattern = r'href="(.*?lora.*?\.safetensors\?download=true)"'
-        match = re.search(pattern, content)
-
-        if match:
-            safetensors_url = match.group(1)
-            filename = safetensors_url.split('/')[-1].split('?')[0] # Extract the filename from the URL
-            return filename
-        else:
-            return None
-
-    except requests.RequestException as e:
-        print(f"An error occurred while fetching the URL: {e}")
-        return None
-
-def enable_lora(lora_scale, lora_in, lora_add):
-    pipe.unload_lora_weights()
+def enable_lora(lora_in, lora_add):
     if not lora_in and not lora_add:
-        return
+        return basemodel
     else:
        if lora_add:
            lora_in = lora_add
-        url = f'https://huggingface.co/{lora_in}/tree/main'
-        lora_name = scrape_lora_link(url)
-        pipe.load_lora_weights(lora_in, weight_name=lora_name)
-        pipe.fuse_lora(lora_scale=lora_scale)
+        return lora_in
 
-@spaces.GPU(duration=100)
-def generate_image(
+@spaces.GPU()
+async def generate_image(
     prompt:str,
+    model:str,
     width:int=768,
     height:int=1024,
     scales:float=3.5,
     steps:int=24,
-    seed:int=-1,
-    nums:int=1):
 
-    pipe.to(device="cuda")
+    seed:int=-1):
 
@@ -88,26 +60,29 @@ def generate_image(
 
     text = str(translator.translate(prompt, 'English'))
 
-    generator = torch.Generator().manual_seed(seed)
-
+    #generator = torch.Generator().manual_seed(seed)
 
-    image = pipe(
+    image1 = await client.text_to_image(
         prompt=text,
         height=height,
         width=width,
         guidance_scale=scales,
-        output_type="pil",
         num_inference_steps=steps,
-        max_sequence_length=512,
-        num_images_per_prompt=nums,
-        generator=generator,
-    ).images
+        model=basemodel,
+    )
+    image2 = await client.text_to_image(
+        prompt=text,
+        height=height,
+        width=width,
+        guidance_scale=scales,
+        num_inference_steps=steps,
+        model=model,
+    )
 
-    return image, seed
+    return image1, image2, seed
 
-def gen(
+async def gen(
     prompt:str,
-    lora_scale:float=1.0,
     lora_in:str="",
     lora_add:str="",
     width:int=768,
@@ -115,11 +90,10 @@ def gen(
     scales:float=3.5,
     steps:int=24,
     seed:int=-1,
-    nums:int=1,
     progress=gr.Progress(track_tqdm=True)
 ):
-    enable_lora(lora_scale, lora_in, lora_add)
-    return generate_image(prompt,width,height,scales,steps,seed,nums)
+    model = enable_lora(lora_in, lora_add)
+    return await generate_image(prompt,model,width,height,scales,steps,seed)
 
 
 
@@ -151,11 +125,13 @@ examples = [
 # Gradio Interface
 
 with gr.Blocks(css=CSS, js=JS, theme="Nymbo/Nymbo_Theme") as demo:
-    gr.HTML("<h1><center>Flux Labs</center></h1>")
-    gr.HTML("<p><center>Choose the LoRA model on the right menu</center></p>")
+    gr.HTML("<h1><center>Flux Labs (vs LoRA)</center></h1>")
+    gr.HTML("<p><center>Choose the LoRA model on the menu</center></p>")
     with gr.Row():
         with gr.Column(scale=4):
-            img = gr.Gallery(label='flux Generated Image', columns = 1, preview=True, height=600)
+            with gr.Row():
+                img1 = gr.Gallery(label='flux Generated Image', columns = 1, preview=True, height=600)
+                img2 = gr.Gallery(label='LoRA Generated Image', columns = 1, preview=True, height=600)
             with gr.Row():
                 prompt = gr.Textbox(label='Enter Your Prompt (Multi-Languages)', placeholder="Enter prompt...", scale=6)
                 sendBtn = gr.Button(scale=1, variant='primary')
@@ -196,20 +172,6 @@ with gr.Blocks(css=CSS, js=JS, theme="Nymbo/Nymbo_Theme") as demo:
                 step=1,
                 value=-1,
             )
-            nums = gr.Slider(
-                label="Image Numbers",
-                minimum=1,
-                maximum=4,
-                step=1,
-                value=1,
-            )
-            lora_scale = gr.Slider(
-                label="LoRA Scale",
-                minimum=0.1,
-                maximum=1.0,
-                step=0.1,
-                value=1.0,
-            )
             lora_in = gr.Dropdown(
                 choices=["Shakker-Labs/FLUX.1-dev-LoRA-blended-realistic-illustration", "Shakker-Labs/AWPortrait-FL",""],
                 label="LoRA Model",
@@ -223,8 +185,8 @@ with gr.Blocks(css=CSS, js=JS, theme="Nymbo/Nymbo_Theme") as demo:
     )
     gr.Examples(
         examples=examples,
-        inputs=[prompt,lora_scale,lora_in],
-        outputs=[img, seed],
+        inputs=[prompt,lora_in],
+        outputs=[img1, img2, seed],
         fn=gen,
         cache_examples="lazy",
         examples_per_page=4,
@@ -238,17 +200,15 @@ with gr.Blocks(css=CSS, js=JS, theme="Nymbo/Nymbo_Theme") as demo:
         fn=gen,
         inputs=[
             prompt,
-            lora_scale,
             lora_in,
             lora_add,
             width,
             height,
             scales,
             steps,
-            seed,
-            nums
+            seed
         ],
-        outputs=[img, seed],
+        outputs=[img1, img2, seed],
         api_name="run",
     )
216