# Spaces: Paused
# (Hugging Face Spaces status text accidentally captured at the top of this
# module; kept as a comment so the file remains valid, importable Python.)
from typing import Annotated | |
from fastapi import FastAPI, Path, Query, Response | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import FileResponse, HTMLResponse | |
from fastapi.staticfiles import StaticFiles | |
from pydantic import BaseModel | |
import torch | |
from torch import autocast | |
from diffusers import DiffusionPipeline | |
from io import BytesIO | |
import base64 | |
from os.path import dirname | |
# class Prompt(BaseModel): | |
# prompt: str | |
# steps: Annotated[int, Path(title="No of steps", ge=4, le=10)] = 8 | |
# guide: Annotated[float, Path(title="Guidance scale", ge=0.5, le=2)] = 0.8 | |
# FastAPI application instance; route handlers below attach to this object.
app = FastAPI()
# Allow cross-origin requests so a browser front-end served from another
# host/port can call the image-generation endpoint.
# NOTE(review): allow_credentials=True combined with allow_origins=["*"] is a
# maximally permissive CORS policy — confirm this is intended before exposing
# the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"]
)
# Prefer the GPU, but fall back to CPU so the module can still be imported on
# machines without CUDA (the previous hard-coded "cuda" crashed there).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the text-to-image diffusion pipeline bundled next to this file in the
# "cgt2im" directory. The auth-token / safetensors options were previously
# disabled and stay disabled.
pipe = DiffusionPipeline.from_pretrained(
    f'{dirname(__file__)}/cgt2im',
    # use_auth_token=auth_token,
    # use_safetensors=True
)
# Half precision halves GPU memory use; many fp16 ops are unsupported or very
# slow on CPU, so only use it when running on CUDA (CUDA path is unchanged).
pipe = pipe.to(device, dtype=torch.float16 if device == "cuda" else torch.float32)
# Route decorator restored: the active handler had none, so the endpoint was
# never registered with FastAPI. The path "/" matches the earlier version of
# this same handler (previously present here as commented-out code, now
# removed as dead code).
@app.get("/")
def generate(prompt: str,
             steps: Annotated[int, Query(ge=4, le=10)] = 8,
             guide: Annotated[float, Query(ge=0.5, le=2)] = 0.8,
             ):
    """Generate an image from a text prompt and return it base64-encoded.

    Parameters
    ----------
    prompt : str
        Text prompt describing the image to generate.
    steps : int
        Number of inference steps, validated to 4..10 (default 8).
    guide : float
        Guidance scale, validated to 0.5..2 (default 0.8).

    Returns
    -------
    Response
        Body is the *base64-encoded* PNG bytes served with an ``image/png``
        media type; callers must base64-decode the body to get the raw PNG.
    """
    # autocast enables mixed-precision execution on the selected device.
    with autocast(device):
        image = pipe(
            prompt=prompt,
            num_inference_steps=steps,
            guidance_scale=guide,
            # LCM-specific parameter: number of original DDIM steps the
            # consistency model was distilled from.
            lcm_origin_steps=50,
            output_type="pil",
        ).images[0]
    # Serialize the PIL image to PNG in memory, then base64-encode for the body.
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    imgstr = base64.b64encode(buffer.getvalue())
    return Response(content=imgstr, media_type="image/png")
async def read_home():
    """Serve the app's static landing page as an HTML response.

    NOTE(review): this handler carries no route decorator, so it is never
    registered with FastAPI — presumably it was meant to be mounted on some
    path; confirm the intended route and add the decorator.
    """
    with open("app/static/index.html", "r") as html_file:
        page_markup = html_file.read()
    return HTMLResponse(content=page_markup)
# @app.post("/t2i") | |
# def generate(prompt: Prompt): | |
# with autocast(device): | |
# image = pipe( | |
# prompt=prompt.prompt, | |
# num_inference_steps=prompt.steps, | |
# guidance_scale=prompt.guide, | |
# lcm_origin_steps=50, | |
# output_type="pil", | |
# ).images[0] | |
# # image.save("testimage.png") | |
# buffer = BytesIO() | |
# image.save(buffer, format="PNG") | |
# imgstr = base64.b64encode(buffer.getvalue()) | |
# return Response(content=imgstr, media_type="image/png") |