# UrangDiffusion / hft2is.py
import gradio as gr
import asyncio
from threading import RLock
from pathlib import Path
import os
from typing import Union
HF_TOKEN = os.getenv("HF_TOKEN", None)  # optional token for private/gated models
server_timeout = 600  # seconds the InferenceClient waits per request
inference_timeout = 600  # seconds before a generation task is abandoned
lock = RLock()
loaded_models = {}  # model_name -> gr.Interface cache
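# Copy the generated file to a model- and timestamp-named PNG (JST, UTC+9),
# or to an explicit save_path when one is given.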
def rename_image(image_path: Union[str, None], model_name: str, save_path: Union[str, None] = None):
    import shutil
    from datetime import datetime, timezone, timedelta
    if image_path is None: return None
    dt_now = datetime.now(timezone(timedelta(hours=9)))  # JST timestamp
    filename = f"{model_name.split('/')[-1]}_{dt_now.strftime('%Y%m%d_%H%M%S')}.png"
    try:
        if Path(image_path).exists():
            png_path = "image.png"
            if str(Path(image_path).resolve()) != str(Path(png_path).resolve()): shutil.copy(image_path, png_path)
            if save_path is not None:
                new_path = str(Path(png_path).resolve().rename(Path(save_path).resolve()))
            else:
                new_path = str(Path(png_path).resolve().rename(Path(filename).resolve()))
            return new_path
        else:
            return None
    except Exception as e:
        print(e)
        return None
# https://github.com/gradio-app/gradio/blob/main/gradio/external.py
# https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
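# load_from_model mirrors gradio's external.py loader (see links above): it checks
# the model's pipeline_tag via the serverless Inference API, then wraps an
# InferenceClient.text_to_image call in a gr.Interface so the model can be
# queried like any other Gradio app.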
from typing import Literal
def load_from_model(model_name: str, hf_token: Union[str, Literal[False], None] = None):
    import httpx
    import huggingface_hub
    from gradio.exceptions import ModelNotFoundError, TooManyRequestsError
    model_url = f"https://huggingface.co/{model_name}"
    api_url = f"https://api-inference.huggingface.co/models/{model_name}"
    print(f"Fetching model from: {model_url}")
    headers = ({} if hf_token in [False, None] else {"Authorization": f"Bearer {hf_token}"})
    response = httpx.request("GET", api_url, headers=headers)
    if response.status_code != 200:
        raise ModelNotFoundError(
            f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
        )
    p = response.json().get("pipeline_tag")
    if p != "text-to-image": raise ModelNotFoundError(f"This model is not a text-to-image model or is otherwise unsupported: {model_name}.")
    headers["X-Wait-For-Model"] = "true"  # block until the model is loaded on the server
    kwargs = {}
    if hf_token is not None: kwargs["token"] = hf_token
    client = huggingface_hub.InferenceClient(model=model_name, headers=headers, timeout=server_timeout, **kwargs)
    inputs = gr.components.Textbox(label="Input")
    outputs = gr.components.Image(label="Output")
    fn = client.text_to_image
    def query_huggingface_inference_endpoints(*data, **kwargs):
        try:
            data = fn(*data, **kwargs)  # type: ignore
        except huggingface_hub.utils.HfHubHTTPError as e:
            print(e)
            if "429" in str(e): raise TooManyRequestsError() from e
            # on other HTTP errors, fall through and return the input tuple;
            # callers detect a tuple as a failed generation
        except Exception as e:
            print(e)
            raise Exception() from e
        return data
    interface_info = {
        "fn": query_huggingface_inference_endpoints,
        "inputs": inputs,
        "outputs": outputs,
        "title": model_name,
    }
    return gr.Interface(**interface_info)
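# Interfaces are cached in loaded_models so each repo is only fetched once per process.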
def load_model(model_name: str):
    global loaded_models
    if model_name in loaded_models.keys(): return loaded_models[model_name]
    try:
        loaded_models[model_name] = load_from_model(model_name, hf_token=HF_TOKEN)
        print(f"Loaded: {model_name}")
    except Exception as e:
        if model_name in loaded_models.keys(): del loaded_models[model_name]
        print(f"Failed to load: {model_name}")
        print(e)
        return None
    return loaded_models[model_name]
def load_models(models: list):
    for model in models:
        load_model(model)
def warm_model(model_name: str):
    model = load_model(model_name)
    if model:
        try:
            print(f"Warming model: {model_name}")
            infer_body(model, model_name, " ")  # fire a dummy prompt so the server loads the weights
        except Exception as e:
            print(e)
def warm_models(models: list[str]):
    for model in models:
        # run_in_executor submits warm_model to the default thread pool immediately,
        # so each warm-up runs in the background without blocking startup
        asyncio.new_event_loop().run_in_executor(None, warm_model, model)
# https://huggingface.co/docs/api-inference/detailed_parameters
# https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
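# Parameter conventions below: 0 means "omit and use the server default",
# seed == -1 means "randomize".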
def infer_body(client: Union[gr.Interface, object], model_str: str, prompt: str, neg_prompt: str = "",
               height: int = 0, width: int = 0, steps: int = 0, cfg: int = 0, seed: int = -1):
    png_path = "image.png"
    kwargs = {}
    if height > 0: kwargs["height"] = height
    if width > 0: kwargs["width"] = width
    if steps > 0: kwargs["num_inference_steps"] = steps
    if cfg > 0: kwargs["guidance_scale"] = cfg
    if seed == -1: kwargs["seed"] = randomize_seed()
    else: kwargs["seed"] = seed
    try:
        if isinstance(client, gr.Interface): image = client.fn(prompt=prompt, negative_prompt=neg_prompt, **kwargs)
        else: return None
        if isinstance(image, tuple): return None  # the wrapped fn returns its input tuple on failure
        return save_image(image, png_path, model_str, prompt, neg_prompt, height, width, steps, cfg, seed)
    except Exception as e:
        print(e)
        raise Exception(e) from e
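# Async wrapper: offloads the blocking infer_body to a thread and enforces
# the timeout via asyncio.wait_for, cancelling the task if it expires.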
async def infer(model_name: str, prompt: str, neg_prompt: str = "", height: int = 0, width: int = 0,
                steps: int = 0, cfg: int = 0, seed: int = -1,
                save_path: str | None = None, timeout: float = inference_timeout):
    model = load_model(model_name)
    if not model: return None
    task = asyncio.create_task(asyncio.to_thread(infer_body, model, model_name, prompt, neg_prompt,
                                                 height, width, steps, cfg, seed))
    await asyncio.sleep(0)
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except asyncio.TimeoutError as e:
        print(e)
        print(f"Task timed out: {model_name}")
        if not task.done(): task.cancel()
        raise Exception(f"Task timed out: {model_name}") from e
    except Exception as e:
        print(e)
        if not task.done(): task.cancel()
        raise Exception(e) from e
    if task.done() and result is not None:
        with lock:
            image = rename_image(result, model_name, save_path)
        return image
    return None
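# Embed the generation settings as a JSON string in a PNG tEXt chunk keyed "metadata".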
def save_image(image, savefile, modelname, prompt, nprompt, height=0, width=0, steps=0, cfg=0, seed=-1):
    from PIL import PngImagePlugin
    import json
    try:
        metadata = {"prompt": prompt, "negative_prompt": nprompt, "Model": {"Model": modelname.split("/")[-1]}}
        if steps > 0: metadata["num_inference_steps"] = steps
        if cfg > 0: metadata["guidance_scale"] = cfg
        if seed != -1: metadata["seed"] = seed
        if width > 0 and height > 0: metadata["resolution"] = f"{width} x {height}"
        metadata_str = json.dumps(metadata)
        info = PngImagePlugin.PngInfo()
        info.add_text("metadata", metadata_str)
        image.save(savefile, "PNG", pnginfo=info)
        return str(Path(savefile).resolve())
    except Exception as e:
        print(f"Failed to save image file: {e}")
        raise Exception("Failed to save image file") from e
def randomize_seed():
    from random import seed, randint
    MAX_SEED = 2**32 - 1
    seed()
    rseed = randint(0, MAX_SEED)
    return rseed
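# Synchronous bridge for Gradio event handlers: runs the async infer() to completion.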
def gen_image(model_name: str, prompt: str, neg_prompt: str = "", height: int = 0, width: int = 0,
              steps: int = 0, cfg: int = 0, seed: int = -1):
    if model_name in ["NA", ""]: return gr.update()
    # run_until_complete fails on an already-running loop, so a dedicated
    # event loop is created for this call and closed afterwards
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, height, width,
                                               steps, cfg, seed, None, inference_timeout))
    except (Exception, asyncio.CancelledError) as e:
        print(e)
        print(f"Task aborted: {model_name}, Error: {e}")
        raise gr.Error(f"Task aborted: {model_name}, Error: {e}")
    finally:
        loop.close()
    return result
def generate_image_hf(model_name: str, prompt: str, negative_prompt: str, use_defaults: bool, resolution: str,
                      guidance_scale: float, num_inference_steps: int, seed: int, randomize_seed: bool, progress=gr.Progress()):
    if randomize_seed: seed = -1  # -1 tells infer_body to pick a random seed
    if use_defaults:
        prompt = f"{prompt}, best quality, amazing quality, very aesthetic"
        negative_prompt = f"nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], {negative_prompt}"
    width, height = map(int, resolution.split('x'))
    image = gen_image(model_name, prompt, negative_prompt, height, width, num_inference_steps, guidance_scale, seed)
    metadata_text = f"{prompt}\nNegative prompt: {negative_prompt}\nSteps: {num_inference_steps}, Sampler: Euler a, Size: {width}x{height}, Seed: {seed}, CFG scale: {guidance_scale}"
    return image, seed, metadata_text
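# --- Minimal usage sketch (not part of the original app wiring) ---
# A hedged example of calling gen_image() directly as a script. The repo id
# below is illustrative only; any public text-to-image model id should work,
# and HF_TOKEN must be set in the environment for private or gated models.
if __name__ == "__main__":
    demo_model = "stabilityai/stable-diffusion-xl-base-1.0"  # illustrative choice
    saved_path = gen_image(demo_model, "a watercolor fox in a forest", steps=28, cfg=7)
    print(f"Saved: {saved_path}")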