"""FastAPI service that generates Stable Diffusion images and verifies uploads.

``/generate`` renders an image, signs it with ``./scripts/sign.sh`` (passing
an inline JSON assertion describing the model and prompt) and publishes the
result in the static directory. ``/verify`` runs ``./scripts/verify.sh`` on
an uploaded file and reports the C2PA / watermark / original-media fields the
script prints. Everything in the static directory is served at ``/``.
"""

import json
import logging
import os
import shutil
import subprocess
import uuid
from typing import Tuple

import torch
from diffusers import (
    StableDiffusionPipeline,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
)
from fastapi import FastAPI, HTTPException, UploadFile
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

# Directory whose contents are published at "/" by the StaticFiles mount
# at the bottom of this file; every client-visible file is copied here.
STATIC_DIR = "static"

app = FastAPI()


def file_extension(filename: str) -> str:
    """Return the extension of *filename* without the leading dot.

    Only the last suffix of the final path component counts, so
    ``"a.tar.gz"`` gives ``"gz"`` and ``"./out/x.jpg"`` gives ``"jpg"``.
    Returns ``""`` when the name has no extension.
    """
    return os.path.splitext(filename)[1][1:]


def _unique_name(extension: str = "") -> str:
    """Return a random uuid4-based file name, with ``.extension`` if given."""
    stem = str(uuid.uuid4())
    return f"{stem}.{extension}" if extension else stem


def _publish(source_path: str, name: str) -> None:
    """Copy *source_path* into the static directory as *name*."""
    shutil.copy(source_path, os.path.join(STATIC_DIR, name))


def _parse_verify_output(output: bytes) -> Tuple[str, str, str]:
    """Extract the first three ``label: value`` fields printed by verify.sh.

    The lines are, in order: C2PA presence, watermark presence, and the
    original-media path (or ``n/a``). Each value is returned with
    surrounding whitespace and single quotes removed.

    Raises:
        HTTPException: 500 if fewer than three ``label: value`` lines exist.
    """
    # fsdecode (surrogateescape) lets a non-UTF-8 path round-trip back to
    # the same bytes when it is later passed to shutil.copy.
    lines = os.fsdecode(output).splitlines()
    values = []
    for line in lines[:3]:
        _label, sep, value = line.partition(":")
        if not sep:
            break
        values.append(value.strip().strip("'"))
    if len(values) != 3:
        raise HTTPException(
            status_code=500, detail="unexpected output from verify.sh"
        )
    return values[0], values[1], values[2]


@app.get("/generate")
def generate_image(prompt: str, model: str):
    """Generate a 512x512 image for *prompt*, sign it, and publish it.

    ``model`` has the form ``"<model name>,<model version>"``. The name is
    passed to ``StableDiffusionPipeline.from_pretrained``. Name, version and
    prompt are embedded in the assertion handed to ``sign.sh``.

    Returns ``{"response": <file name>}``; the signed image is served from
    the static directory under that name.

    Raises:
        HTTPException: 400 if ``model`` has no ``,<version>`` part.
    """
    model_parts = model.split(",")
    if len(model_parts) < 2:
        raise HTTPException(
            status_code=400, detail="model must be '<name>,<version>'"
        )
    model_name, model_version = model_parts[0], model_parts[1]

    torch.cuda.empty_cache()
    # NOTE(review): model_name comes straight from the query string and is
    # handed to from_pretrained; consider restricting it to an allow-list.
    pipeline = StableDiffusionPipeline.from_pretrained(
        model_name, torch_dtype=torch.float16
    )
    pipeline.scheduler = EulerDiscreteScheduler.from_config(
        pipeline.scheduler.config
    )
    pipeline = pipeline.to("cuda")
    image = pipeline(
        prompt, num_inference_steps=50, height=512, width=512
    ).images[0]

    filename = _unique_name("jpg")
    image.save(filename)

    assertion = {
        "assertions": [
            {
                "label": "com.truepic.custom.ai",
                "data": {
                    "model_name": model_name,
                    "model_version": model_version,
                    "prompt": prompt,
                },
            }
        ]
    }
    subprocess.check_output(
        [
            "./scripts/sign.sh",
            filename,
            "--assertions-inline",
            json.dumps(assertion),
        ]
    )
    # The signed image is taken from the fixed path "output.jpg" —
    # presumably where sign.sh writes its result (confirm in the script).
    # NOTE(review): concurrent requests share that path and can race.
    _publish("output.jpg", filename)
    return {"response": filename}


@app.post("/verify")
def verify_image(fileUpload: UploadFile):
    """Run ``./scripts/verify.sh`` on an uploaded file and report the results.

    Returns a dict with the client's file name and the ``contains_c2pa`` and
    ``contains_watermark`` values printed by the script. It also includes
    ``original_media``: either ``"n/a"`` or the static-directory name of a
    copy of the original-media file the script reported.

    Raises:
        HTTPException: 400 if the upload has no file name; 500 if the
            script's output is not in the expected format.
    """
    logging.warning("in verify")
    logging.warning(fileUpload.filename)
    if not fileUpload.filename:
        raise HTTPException(status_code=400, detail="uploaded file has no name")

    # Save under a random name (keeping only the extension) instead of the
    # client-supplied one, so uploads cannot overwrite server files or each
    # other, and pass that exact path to the script.
    saved_name = _unique_name(file_extension(fileUpload.filename))
    with open(saved_name, "wb") as saved_file:
        shutil.copyfileobj(fileUpload.file, saved_file)

    response = subprocess.check_output(["./scripts/verify.sh", saved_name])
    logging.warning(response)
    c2pa, watermark, original_media = _parse_verify_output(response)

    if original_media != "n/a":
        original_media_extension = file_extension(original_media)
        logging.warning(original_media_extension)
        published_name = _unique_name(original_media_extension)
        _publish(original_media, published_name)
        original_media = published_name

    return {
        "response": fileUpload.filename,
        "contains_c2pa": c2pa,
        "contains_watermark": watermark,
        "original_media": original_media,
    }


@app.get("/")
def index() -> FileResponse:
    """Serve the front-end page ``index.html`` from the static directory."""
    return FileResponse(
        path=os.path.join(STATIC_DIR, "index.html"), media_type="text/html"
    )


# Mounted last: a mount at "/" matches every path, so any route registered
# after it would never be reached.
app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")