"""FastAPI service wrapping a local latent-consistency diffusion pipeline.

Routes:
    GET /t2i -- generate an image from a text prompt; returns base64-encoded
                PNG bytes in the response body.
    GET /    -- serve the static front-end page.
"""

import base64
from io import BytesIO
from os.path import dirname
from typing import Annotated

import torch
from torch import autocast

from diffusers import DiffusionPipeline
from fastapi import FastAPI, Path, Query, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

app = FastAPI()

# NOTE(review): the CORS spec forbids allow_credentials=True together with a
# wildcard origin ("*") -- browsers will reject credentialed requests against
# it. Confirm whether credentials are actually needed; if not, drop the flag.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

device = "cuda"

# Load the checkpoint bundled next to this file ("cgt2im") and move it to the
# GPU in half precision to reduce memory use and speed up inference.
pipe = DiffusionPipeline.from_pretrained(f"{dirname(__file__)}/cgt2im")
pipe = pipe.to(device, dtype=torch.float16)


@app.get("/t2i")
def generate(
    prompt: str,
    steps: Annotated[int, Query(ge=4, le=10)] = 8,
    guide: Annotated[float, Query(ge=0.5, le=2)] = 0.8,
):
    """Run the pipeline on *prompt* and return the resulting image.

    Query parameters:
        prompt: text description of the image to generate.
        steps:  number of inference steps (4-10, default 8).
        guide:  guidance scale (0.5-2.0, default 0.8).

    Returns a Response whose body is *base64-encoded* PNG data. The declared
    media type is image/png even though the payload is base64 text -- the
    front-end presumably decodes it client-side; kept as-is for compatibility.
    """
    with autocast(device):
        # lcm_origin_steps is an LCM-scheduler argument; 50 matches the
        # original training schedule of the distilled model.
        image = pipe(
            prompt=prompt,
            num_inference_steps=steps,
            guidance_scale=guide,
            lcm_origin_steps=50,
            output_type="pil",
        ).images[0]

    buffer = BytesIO()
    image.save(buffer, format="PNG")
    imgstr = base64.b64encode(buffer.getvalue())
    return Response(content=imgstr, media_type="image/png")


@app.get("/", response_class=HTMLResponse)
async def read_home():
    """Serve the static front-end page.

    The path is resolved relative to the process working directory, not this
    file -- assumes the server is started from the project root; verify.
    """
    # Explicit encoding: the platform default is locale-dependent and could
    # mis-decode a UTF-8 HTML file.
    with open("app/static/index.html", "r", encoding="utf-8") as file:
        content = file.read()
    return HTMLResponse(content=content)