salomonsky committed on
Commit
6b3d1c3
1 Parent(s): 8c1a558

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -36
app.py CHANGED
@@ -7,21 +7,33 @@ import streamlit as st
7
  from huggingface_hub import InferenceClient, AsyncInferenceClient
8
  from gradio_client import Client, handle_file
9
  import asyncio
 
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
  HF_TOKEN = os.environ.get("HF_TOKEN")
13
  HF_TOKEN_UPSCALER = os.environ.get("HF_TOKEN_UPSCALER")
14
  client = AsyncInferenceClient()
 
15
  DATA_PATH = Path("./data")
16
  DATA_PATH.mkdir(exist_ok=True)
17
 
18
- async def generate_image(prompt, model, width, height, scales, steps, seed):
 
 
 
 
 
 
 
 
 
 
19
  try:
20
  if seed == -1:
21
  seed = random.randint(0, MAX_SEED)
22
  seed = int(seed)
23
  image = await client.text_to_image(
24
- prompt=prompt, height=height, width=width, guidance_scale=scales,
25
  num_inference_steps=steps, model=model
26
  )
27
  return image, seed
@@ -38,6 +50,63 @@ def get_upscale_finegrain(prompt, img_path, upscale_factor):
38
  except Exception as e:
39
  return None
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def save_image(image, seed):
42
  try:
43
  image_path = DATA_PATH / f"image_{seed}.jpg"
@@ -50,7 +119,8 @@ def save_image(image, seed):
50
  def get_storage():
51
  files = [file for file in DATA_PATH.glob("*.jpg") if file.is_file()]
52
  files.sort(key=lambda x: x.stat().st_mtime, reverse=True)
53
- return [str(file.resolve()) for file in files]
 
54
 
55
  def get_prompts():
56
  prompt_files = [file for file in DATA_PATH.glob("*.txt") if file.is_file()]
@@ -68,18 +138,19 @@ def delete_image(image_path):
68
 
69
  def main():
70
  st.set_page_config(layout="wide")
71
- st.title("Generación de Imágenes")
72
 
73
- prompt = st.text_input("Descripción de la imagen", max_chars=200)
74
-
75
- with st.expander("Opciones avanzadas", expanded=False):
76
- basemodel = st.selectbox("Modelo Base", ["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-DEV"])
77
- format_option = st.selectbox("Formato", ["9:16", "16:9"])
78
- process_upscale = st.checkbox("Procesar Escalador", value=True)
79
- upscale_factor = st.selectbox("Factor de Escala", [2, 4, 8], index=0)
80
- scales = st.slider("Escalado", 1, 20, 10)
81
- steps = st.slider("Pasos", 1, 100, 20)
82
- seed = st.number_input("Semilla", value=-1)
 
83
 
84
  if format_option == "9:16":
85
  width = 720
@@ -88,35 +159,48 @@ def main():
88
  width = 1280
89
  height = 720
90
 
91
- if st.button("Generar Imagen"):
92
- with st.spinner("Generando imagen..."):
93
- image, seed = await generate_image(prompt, basemodel, width, height, scales, steps, seed)
94
-
95
- if isinstance(image, str) and image.startswith("Error"):
96
- st.error(image)
 
 
 
 
 
97
  else:
98
- image_path = save_image(image, seed)
99
- if image_path:
100
- st.image(image_path, caption="Imagen Generada")
101
- st.success("Imagen generada y guardada.")
102
-
103
- # Mostrar galería de imágenes
104
- files = get_storage()
105
- prompts = get_prompts()
106
 
107
- st.subheader("Galería de Imágenes")
108
- cols = st.columns(3)
 
 
 
 
 
 
 
 
109
 
110
  for idx, file in enumerate(files):
111
- with cols[idx % 3]:
112
  image = Image.open(file)
113
- prompt_text = prompts.get(Path(file).stem.replace("image_", ""), "No disponible")
 
114
 
115
- st.image(image, caption=f"Imagen {idx + 1}")
116
  st.write(f"Prompt: {prompt_text}")
117
 
118
- if st.button(f"Borrar Imagen {idx + 1}", key=f"delete_{idx}"):
119
- delete_image(file)
 
 
 
 
 
 
120
 
121
  if __name__ == "__main__":
122
- main()
 
7
  from huggingface_hub import InferenceClient, AsyncInferenceClient
8
  from gradio_client import Client, handle_file
9
  import asyncio
10
+ from concurrent.futures import ThreadPoolExecutor
11
 
12
  MAX_SEED = np.iinfo(np.int32).max
13
  HF_TOKEN = os.environ.get("HF_TOKEN")
14
  HF_TOKEN_UPSCALER = os.environ.get("HF_TOKEN_UPSCALER")
15
  client = AsyncInferenceClient()
16
+ llm_client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
17
  DATA_PATH = Path("./data")
18
  DATA_PATH.mkdir(exist_ok=True)
19
 
20
def run_async(func):
    """Run a blocking callable on a worker thread and return its result.

    A private event loop drives a single-worker ThreadPoolExecutor so the
    call does not block the caller's loop. The original version leaked
    both the loop and the executor on every invocation; they are now
    released in a ``finally`` block.

    Args:
        func: a zero-argument *synchronous* callable (it is submitted to
            the executor, not awaited — despite the function's name).

    Returns:
        Whatever ``func()`` returns.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    executor = ThreadPoolExecutor(max_workers=1)
    try:
        return loop.run_until_complete(loop.run_in_executor(executor, func))
    finally:
        # Release the per-call resources the original version leaked.
        executor.shutdown(wait=False)
        asyncio.set_event_loop(None)
        loop.close()
26
+
27
def enable_lora(lora_add, basemodel):
    """Choose the model id: the LoRA repo when one is given, else the base model."""
    if lora_add:
        return lora_add
    return basemodel
29
+
30
+ async def generate_image(combined_prompt, model, width, height, scales, steps, seed):
31
  try:
32
  if seed == -1:
33
  seed = random.randint(0, MAX_SEED)
34
  seed = int(seed)
35
  image = await client.text_to_image(
36
+ prompt=combined_prompt, height=height, width=width, guidance_scale=scales,
37
  num_inference_steps=steps, model=model
38
  )
39
  return image, seed
 
50
  except Exception as e:
51
  return None
52
 
53
def save_prompt(prompt_text, seed):
    """Persist the prompt used for a generation next to its image.

    Writes ``prompt_<seed>.txt`` under DATA_PATH (UTF-8, explicit encoding
    instead of the platform default the old ``open()`` call relied on).

    Args:
        prompt_text: the full prompt string to store.
        seed: the generation seed; used to pair the file with image_<seed>.jpg.

    Returns:
        The ``Path`` of the written file, or ``None`` if writing failed
        (the error is surfaced in the Streamlit UI rather than raised).
    """
    try:
        prompt_file_path = DATA_PATH / f"prompt_{seed}.txt"
        prompt_file_path.write_text(prompt_text, encoding="utf-8")
        return prompt_file_path
    except Exception as e:
        st.error(f"Error al guardar el prompt: {e}")
        return None
62
+
63
async def gen(prompt, basemodel, width, height, scales, steps, seed, upscale_factor, process_upscale, lora_model, process_lora, process_enhancer):
    """Full generation pipeline: optional prompt enhancement, image generation,
    saving, and optional upscaling, with a Streamlit progress bar.

    Returns:
        A 2-element list ``[image_path_or_error, prompt_file_path_or_None]``.
        (The old error branch returned a 3-element list while the success
        branches returned 2 — the only caller reads indices 0 and 1, so the
        shape is now consistent.)
    """
    model = enable_lora(lora_model, basemodel) if process_lora else basemodel
    combined_prompt = prompt

    if process_enhancer:
        # Append the LLM-improved prompt to the user's original text.
        improved_prompt = await improve_prompt(prompt)
        combined_prompt = f"{prompt} {improved_prompt}"

    if seed == -1:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    progress_bar = st.progress(0)
    image, seed = await generate_image(combined_prompt, model, width, height, scales, steps, seed)
    progress_bar.progress(50)

    # generate_image signals failure by returning an "Error..." string.
    if isinstance(image, str) and image.startswith("Error"):
        progress_bar.empty()
        return [image, None]

    image_path = save_image(image, seed)
    prompt_file_path = save_prompt(combined_prompt, seed)

    if process_upscale:
        upscale_image_path = get_upscale_finegrain(combined_prompt, image_path, upscale_factor)
        if upscale_image_path:
            # Re-encode the upscaled result as JPEG and drop the original file.
            final_path = DATA_PATH / f"upscale_image_{seed}.jpg"
            upscale_image = Image.open(upscale_image_path)
            upscale_image.save(final_path, format="JPEG")
            progress_bar.progress(100)
            image_path.unlink()
            return [str(final_path), str(prompt_file_path)]
        else:
            # Upscaling failed: fall back to the un-upscaled image.
            progress_bar.empty()
            return [str(image_path), str(prompt_file_path)]
    else:
        progress_bar.progress(100)
        return [str(image_path), str(prompt_file_path)]
99
+
100
async def improve_prompt(prompt):
    """Ask the Mixtral LLM to expand *prompt* into a richer txt2img prompt.

    Returns the improved prompt capped at 300 characters, or an
    ``"Error mejorando el prompt: ..."`` string on failure (callers detect
    errors by prefix).

    Bug fixed: ``InferenceClient.text_generation`` returns a plain ``str``
    by default, so the old ``'generated_text' in response`` test was a
    *substring* check, not a key lookup. We now branch on the actual type.
    """
    try:
        # NOTE(review): "ilumination"/"admosphere" misspellings are kept
        # byte-identical — this string is sent to the LLM at runtime.
        instruction = ("With this idea, describe in English a detailed txt2img prompt in 500 characters at most, add ilumination, admosphere, cinematic and characters...")
        formatted_prompt = f"{prompt}: {instruction}"
        response = llm_client.text_generation(formatted_prompt, max_new_tokens=300)
        if isinstance(response, dict):
            improved_text = response.get("generated_text", "").strip()
        else:
            improved_text = str(response).strip()
        # Slicing already handles the <=300 case; no conditional needed.
        return improved_text[:300]
    except Exception as e:
        return f"Error mejorando el prompt: {e}"
109
+
110
  def save_image(image, seed):
111
  try:
112
  image_path = DATA_PATH / f"image_{seed}.jpg"
 
119
def get_storage():
    """Return saved JPG paths (newest first) plus a human-readable usage string."""
    stored = sorted(
        (p for p in DATA_PATH.glob("*.jpg") if p.is_file()),
        key=lambda p: p.stat().st_mtime,
        reverse=True,
    )
    usage = sum(p.stat().st_size for p in stored)
    paths = [str(p.resolve()) for p in stored]
    return paths, f"Uso total: {usage/(1024.0 ** 3):.3f}GB"
124
 
125
  def get_prompts():
126
  prompt_files = [file for file in DATA_PATH.glob("*.txt") if file.is_file()]
 
138
 
139
def main():
    """Streamlit entry point: sidebar controls, generation trigger, gallery."""
    st.set_page_config(layout="wide")
    st.title("FLUX with prompt enhancer and upscaler with LORA model training")

    # --- Sidebar: generation parameters ---
    prompt = st.sidebar.text_input("Descripción de la imagen", max_chars=200)
    process_enhancer = st.sidebar.checkbox("Mejorar Prompt", value=True)
    basemodel = st.sidebar.selectbox("Modelo Base", ["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-DEV"])
    lora_model = st.sidebar.selectbox("LORA Realismo", ["Shakker-Labs/FLUX.1-dev-LoRA-add-details", "XLabs-AI/flux-RealismLora"])
    format_option = st.sidebar.selectbox("Formato", ["9:16", "16:9"])
    process_lora = st.sidebar.checkbox("Procesar LORA", value=True)
    process_upscale = st.sidebar.checkbox("Procesar Escalador", value=True)
    upscale_factor = st.sidebar.selectbox("Factor de Escala", [2, 4, 8], index=0)
    scales = st.sidebar.slider("Escalado", 1, 20, 10)
    steps = st.sidebar.slider("Pasos", 1, 100, 20)
    seed = st.sidebar.number_input("Semilla", value=-1)

    # Map the aspect-ratio choice to concrete pixel dimensions.
    if format_option == "9:16":
        width, height = 720, 1280
    else:
        width, height = 1280, 720

    if st.sidebar.button("Generar Imagen"):
        with st.spinner("Mejorando y generando imagen..."):
            result = asyncio.run(gen(prompt, basemodel, width, height, scales, steps, seed, upscale_factor, process_upscale, lora_model, process_lora, process_enhancer))
            image_paths = result[0]
            prompt_file = result[1]

            st.write(f"Image paths: {image_paths}")

            if image_paths:
                if Path(image_paths).exists():
                    st.image(image_paths, caption="Imagen Generada")
                else:
                    st.error("El archivo de imagen no existe.")

            if prompt_file and Path(prompt_file).exists():
                prompt_text = Path(prompt_file).read_text()
                st.write(f"Prompt utilizado: {prompt_text}")
            else:
                st.write("El archivo del prompt no está disponible.")

    # --- Gallery of stored images ---
    # NOTE(review): rendered on every run, not only after pressing the
    # button — confirm this matches the intended layout.
    files, usage = get_storage()
    st.text(usage)
    cols = st.columns(6)
    prompts = get_prompts()

    for idx, file in enumerate(files):
        with cols[idx % 6]:
            image = Image.open(file)
            # FIX: key the prompt lookup on the trailing seed segment.
            # The old ``stem.replace("image_", "")`` produced
            # "upscale_<seed>" for upscaled files ("upscale_image_<seed>.jpg",
            # the name gen() writes), so their prompts never resolved.
            prompt_file = prompts.get(Path(file).stem.split("_")[-1], None)
            prompt_text = Path(prompt_file).read_text() if prompt_file else "No disponible"

            st.image(image, caption=f"Imagen {idx+1}")
            st.write(f"Prompt: {prompt_text}")

            if st.button(f"Borrar Imagen {idx+1}", key=f"delete_{idx}"):
                try:
                    os.remove(file)
                    if prompt_file:
                        os.remove(prompt_file)
                    st.success(f"Imagen {idx+1} y su prompt fueron borrados.")
                except Exception as e:
                    st.error(f"Error al borrar la imagen o prompt: {e}")
204
 
205
  if __name__ == "__main__":
206
+ main()