File size: 3,933 Bytes
dddb041 b8a701a 7ce51a2 b8a701a dddb041 223f8b4 49f145b a146eda ccde57b a146eda ccde57b b8a701a 0347515 b8a701a 0347515 a146eda 49f145b 0347515 7ce51a2 9f29ad9 baeab79 7ce51a2 322db57 7ce51a2 322db57 7ce51a2 db08793 a146eda 322db57 bf71179 7ce51a2 9f29ad9 538d554 057bc07 a200bb2 1b088a5 057bc07 c133494 057bc07 c49f67b a146eda 057bc07 7ce51a2 a200bb2 057bc07 c49f67b 057bc07 c03b3ba 7ce51a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import gradio as gr
import requests
import io
from PIL import Image
import json
import os
import logging
import math
from tqdm import tqdm
import time
logging.basicConfig(level=logging.DEBUG)

# Gallery metadata loaded once at import time. Presumably a JSON array of
# objects (it is indexed by integer position and iterated below); each entry is
# read for the keys "title", "image", "repo" and "trigger_word".
# An explicit encoding keeps non-ASCII titles from depending on the host locale.
with open("loras.json", "r", encoding="utf-8") as f:
    loras = json.load(f)
def update_selection(selected_state: gr.SelectData):
    """React to a click on a gallery item.

    Looks up the clicked LoRA by the event's ``index``, points the prompt box's
    placeholder at that LoRA's title, and returns the event object unchanged so
    it can be kept as the current selection.

    Returns:
        A 2-tuple: a ``gr.update`` for the prompt textbox, and the
        ``selected_state`` event object itself.
    """
    logging.debug(f"Inside update_selection, selected_state: {selected_state}")
    chosen = loras[selected_state.index]
    placeholder_text = f"Type a prompt for {chosen['title']}"
    textbox_update = gr.update(placeholder=placeholder_text)
    return textbox_update, selected_state
# Retry / timeout policy for the Inference API call made by run_lora.
_REQUEST_TIMEOUT_SECONDS = 300  # per-request socket timeout; avoids hanging forever
_RETRY_DELAY_SECONDS = 1
_MAX_SERVER_ERROR_RETRIES = 5  # consecutive-or-not HTTP 500 responses tolerated
_MAX_COLD_BOOT_RETRIES = 600  # HTTP 503 polls tolerated while the model loads
_NEGATIVE_PROMPT = "bad art, ugly, watermark, deformed"


def run_lora(prompt, selected_state, progress=gr.Progress(track_tqdm=True)):
    """Generate an image for *prompt* using the currently selected LoRA.

    Posts ``"<prompt> <trigger_word>"`` (plus a fixed negative prompt) to the
    Hugging Face Inference API endpoint for the LoRA's ``repo`` and retries on
    transient statuses until an image is returned.

    Args:
        prompt: Text typed by the user.
        selected_state: The gallery selection event stored in ``gr.State``;
            only its ``index`` attribute is read. ``None`` if nothing has
            been selected yet.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors the
            tqdm bar below into the UI.

    Returns:
        A ``PIL.Image.Image`` decoded from the successful response body.

    Raises:
        gr.Error: if no LoRA is selected, the HTTP request fails or times out,
            the model keeps returning 503 past the retry budget, 500s exceed
            the retry budget, or any other non-200 status is returned.
    """
    logging.debug(f"Inside run_lora, selected_state: {selected_state}")
    if not selected_state:
        logging.error("selected_state is None or empty.")
        raise gr.Error("You must select a LoRA")

    selected_lora = loras[selected_state.index]
    api_url = f"https://api-inference.huggingface.co/models/{selected_lora['repo']}"
    trigger_word = selected_lora["trigger_word"]
    # NOTE: requests are sent without an Authorization header; a Bearer token
    # read from the API_TOKEN environment variable was previously disabled here.
    payload = {
        "inputs": f"{prompt} {trigger_word}",
        "parameters": {"negative_prompt": _NEGATIVE_PROMPT},
    }
    logging.debug("API Request: %s", api_url)
    logging.debug("API Payload: %s", payload)

    server_errors = 0
    cold_boot_polls = 0
    # The context manager closes the bar on every exit path (return or raise).
    with tqdm(total=None, desc="Loading model") as pbar:
        while True:
            try:
                response = requests.post(
                    api_url, json=payload, timeout=_REQUEST_TIMEOUT_SECONDS
                )
            except requests.exceptions.RequestException as err:
                logging.error("API request failed: %s", err)
                raise gr.Error("API Error: Unable to fetch the image.") from err

            status = response.status_code
            if status == 200:
                return Image.open(io.BytesIO(response.content))
            if status == 503 and cold_boot_polls < _MAX_COLD_BOOT_RETRIES:
                # Per the original author, 503 is returned while the model is
                # cold-booting; the response's load-time estimate is imprecise,
                # so we simply poll once per delay. Bounded so a model that
                # never loads cannot occupy a queue worker forever.
                time.sleep(_RETRY_DELAY_SECONDS)
                pbar.update(1)
                cold_boot_polls += 1
                continue
            if status == 500 and server_errors < _MAX_SERVER_ERROR_RETRIES:
                logging.warning("API returned 500, retrying: %r", response.content)
                time.sleep(_RETRY_DELAY_SECONDS)
                server_errors += 1
                continue

            logging.error("API Error: %s", status)
            raise gr.Error("API Error: Unable to fetch the image.")
# --- UI ------------------------------------------------------------------
# Left: gallery of LoRAs. Right: prompt box + Run button and the output image.
# Selecting a gallery item stores the event in `selected_state`; submitting the
# prompt or clicking Run passes that state to `run_lora`.
with gr.Blocks(css="custom.css") as app:
    title = gr.Markdown("# artificialguybr LoRA portfolio")
    description = gr.Markdown(
        "### This is my portfolio. Follow me on Twitter [@artificialguybr](https://twitter.com/artificialguybr). "
        "Note: The speed and generation quality are for demonstration purposes. "
        "For best quality, use Auto or Comfy. Special thanks to Hugging Face for their free inference API."
    )
    selected_state = gr.State()

    with gr.Row():
        gallery = gr.Gallery(
            [(entry["image"], entry["title"]) for entry in loras],
            label="LoRA Gallery",
            allow_preview=False,
            columns=3,
        )
        with gr.Column():
            prompt_title = gr.Markdown("### Click on a LoRA in the gallery to select it")
            with gr.Row():
                prompt = gr.Textbox(
                    label="Prompt",
                    show_label=False,
                    lines=1,
                    max_lines=1,
                    placeholder="Type a prompt after selecting a LoRA",
                )
                button = gr.Button("Run")
            result = gr.Image(interactive=False, label="Generated Image")

    # Clicking a gallery item updates the prompt placeholder and the stored selection.
    gallery.select(update_selection, outputs=[prompt, selected_state])

    # Pressing Enter in the prompt box and clicking Run trigger the same generation.
    for register_listener in (prompt.submit, button.click):
        register_listener(
            fn=run_lora,
            inputs=[prompt, selected_state],
            outputs=[result],
        )

app.queue(max_size=20, concurrency_count=5)
app.launch()