cgt2ids7 / app /cloudgate.py
khawir's picture
Add application file
0b8378a
raw
history blame
3.1 kB
import base64
import threading
from io import BytesIO
from os.path import dirname
from typing import Annotated

import torch
from diffusers import DiffusionPipeline
from fastapi import FastAPI, Path, Query, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from torch import autocast
# class Prompt(BaseModel):
# prompt: str
# steps: Annotated[int, Path(title="No of steps", ge=4, le=10)] = 8
# guide: Annotated[float, Path(title="Guidance scale", ge=0.5, le=2)] = 0.8
app = FastAPI()

# Cross-origin policy: every origin, method and header is accepted.
# NOTE(review): browsers refuse a wildcard origin on credentialed requests,
# and nothing visible in this module uses cookies or auth, so
# ``allow_credentials=True`` may be unnecessary here; confirm before relying
# on it.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
    "allow_credentials": True,
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# Device used for inference. The pipeline below is half precision, which is
# intended for GPU execution.
device = "cuda"

# Load the local model directory (``cgt2im`` next to this module) once at
# import time; every request reuses this single pipeline object.
# ``torch_dtype`` is the documented way to obtain a float16 pipeline. The
# previous ``.to(device, dtype=...)`` keyword form is only understood by
# newer diffusers releases.
pipe = DiffusionPipeline.from_pretrained(
    f'{dirname(__file__)}/cgt2im',
    torch_dtype=torch.float16,
)
pipe = pipe.to(device)
# @app.get("/")
# def generate(prompt: str):
# with autocast(device):
# image = pipe(
# prompt=prompt,
# num_inference_steps=8,
# guidance_scale=8.0,
# lcm_origin_steps=50,
# output_type="pil",
# ).images[0]
# # image.save("testimage.png")
# buffer = BytesIO()
# image.save(buffer, format="PNG")
# imgstr = base64.b64encode(buffer.getvalue())
# return Response(content=imgstr, media_type="image/png")
# Serializes use of the shared ``pipe``. FastAPI runs plain ``def`` endpoints
# in a worker thread pool, so several requests can be inside this handler at
# once. diffusers pipelines are not thread-safe: the scheduler keeps
# per-call state such as the timestep schedule derived from
# ``num_inference_steps``, so concurrent calls with different ``steps`` can
# corrupt each other.
_pipe_lock = threading.Lock()


@app.get("/t2i")
def generate(prompt: str,
             steps: Annotated[int, Query(ge=4, le=10)] = 8,
             guide: Annotated[float, Query(ge=0.5, le=2)] = 0.8,
             ):
    """Generate one image for *prompt* and return it base64-encoded.

    Query parameters:
        prompt: Text prompt passed to the diffusion pipeline.
        steps: Number of inference steps; FastAPI rejects values outside 4..10.
        guide: Guidance scale; FastAPI rejects values outside 0.5..2.

    Returns:
        A ``Response`` whose body is the base64 encoding of the PNG bytes.
        NOTE(review): the media type says ``image/png`` although the body is
        base64 text, not raw PNG. It is kept as-is because the front end
        presumably decodes the base64; confirm before changing either side.
    """
    # Only one request at a time may drive the pipeline (see _pipe_lock).
    with _pipe_lock, autocast(device):
        image = pipe(
            prompt=prompt,
            num_inference_steps=steps,
            guidance_scale=guide,
            lcm_origin_steps=50,
            output_type="pil",
        ).images[0]

    # PNG encoding happens outside the lock so the next request can start.
    with BytesIO() as buffer:
        image.save(buffer, format="PNG")
        imgstr = base64.b64encode(buffer.getvalue())
    return Response(content=imgstr, media_type="image/png")
@app.get("/", response_class=HTMLResponse)
async def read_home():
    """Serve the front-end page ``static/index.html``.

    The file is re-read on every request, so edits to it take effect without
    restarting the server.
    """
    # Resolve the page relative to this module, as the model path above is,
    # instead of the process working directory. The old "app/static/..."
    # path only worked when the server was launched from the directory that
    # contains ``app/``.
    index_path = f"{dirname(__file__)}/static/index.html"
    # An explicit encoding avoids depending on the host's locale default.
    with open(index_path, "r", encoding="utf-8") as file:
        content = file.read()
    return HTMLResponse(content=content)
# @app.post("/t2i")
# def generate(prompt: Prompt):
# with autocast(device):
# image = pipe(
# prompt=prompt.prompt,
# num_inference_steps=prompt.steps,
# guidance_scale=prompt.guide,
# lcm_origin_steps=50,
# output_type="pil",
# ).images[0]
# # image.save("testimage.png")
# buffer = BytesIO()
# image.save(buffer, format="PNG")
# imgstr = base64.b64encode(buffer.getvalue())
# return Response(content=imgstr, media_type="image/png")