|
import contextlib
import json
import logging
import os
import shutil
import subprocess
import tempfile
import uuid

import torch
from diffusers import (
    StableDiffusionPipeline,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
)
from fastapi import FastAPI, HTTPException, UploadFile
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
|
|
|
# FastAPI application object; the endpoints in this module register on it
# through @app.get / @app.post decorators, and static files via app.mount.
app = FastAPI()
|
|
|
def file_extension(filename):
    """Return the extension of *filename* without its leading dot.

    Only the final suffix of the last path component is used, so dots in
    directory names (``"./media/x.jpg"``) or in the stem (``"a.b.png"``) do
    not confuse it. Returns ``""`` when there is no extension.

    >>> file_extension("photo.jpg")
    'jpg'
    """
    # splitext returns the suffix including the dot (or ""), so drop one char.
    return os.path.splitext(filename)[1][1:]
|
|
|
|
|
@app.get("/generate")
def generate_image(prompt, model):
    """Render a 512x512 image for *prompt*, sign it, and publish it under ``static/``.

    Args:
        prompt: Text prompt for the Stable Diffusion pipeline; it is also
            recorded in the signed assertion.
        model: ``"<model_name>,<model_version>"``. The name is passed to
            ``StableDiffusionPipeline.from_pretrained``; the version is only
            recorded in the signed assertion.

    Returns:
        ``{"response": filename}`` where *filename* is served from ``static/``.

    Raises:
        HTTPException: 400 if *model* has no comma-separated version.
        subprocess.CalledProcessError: if ``scripts/sign.sh`` exits non-zero.
    """
    model_parts = model.split(",")
    if len(model_parts) < 2:
        # Previously an IndexError (HTTP 500); this is a client input error.
        raise HTTPException(
            status_code=400, detail="model must be '<model_name>,<model_version>'"
        )
    model_name, model_version = model_parts[0], model_parts[1]

    # Free cached-but-unused CUDA memory before loading another pipeline.
    torch.cuda.empty_cache()

    pipeline = StableDiffusionPipeline.from_pretrained(
        model_name, torch_dtype=torch.float16
    )
    pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to("cuda")

    image = pipeline(prompt, num_inference_steps=50, height=512, width=512).images[0]

    # Unsigned render; only used as input to sign.sh.
    filename = str(uuid.uuid4()) + ".jpg"
    image.save(filename)

    assertion = {
        "assertions": [
            {
                "label": "com.truepic.custom.ai",
                "data": {
                    "model_name": model_name,
                    "model_version": model_version,
                    "prompt": prompt,
                },
            }
        ]
    }

    try:
        # Argument list (no shell), so the user-supplied prompt is not interpreted.
        subprocess.check_output(
            [
                "./scripts/sign.sh",
                filename,
                "--assertions-inline",
                json.dumps(assertion),
            ]
        )
        # NOTE(review): the signed result is read from a fixed "output.jpg", so
        # concurrent requests can pick up each other's output — confirm whether
        # sign.sh can take an explicit output path.
        shutil.copyfile("output.jpg", os.path.join("static", filename))
    finally:
        # The unsigned render is never served; don't accumulate it in the CWD.
        with contextlib.suppress(FileNotFoundError):
            os.remove(filename)

    return {"response": filename}
|
|
|
|
|
def _verify_field(line):
    """Return the trimmed value after the first ':' of a ``label: value`` line.

    Raises:
        HTTPException: 500 if *line* contains no ':'.
    """
    _label, sep, value = line.partition(":")
    if not sep:
        raise HTTPException(
            status_code=500, detail=f"unexpected verify.sh output line: {line!r}"
        )
    return value.strip()


def _parse_verify_output(output):
    """Split verify.sh output into ``(c2pa, watermark, original_media)``.

    Only the first three lines are used, in that fixed order.

    Raises:
        HTTPException: 500 if there are fewer than three lines or a line has no ':'.
    """
    lines = output.splitlines()
    if len(lines) < 3:
        raise HTTPException(
            status_code=500, detail="verify.sh produced fewer than 3 output lines"
        )
    c2pa, watermark, original_media = (_verify_field(line) for line in lines[:3])
    return c2pa, watermark, original_media


def _publish_original_media(path):
    """Copy *path* into ``static/`` under a random name and return that name.

    The extension of *path* (including its leading '.') is kept so the static
    file server can infer the content type.
    """
    # NOTE(review): *path* comes verbatim from verify.sh output; confirm it can
    # never be steered by the uploaded file, or this publishes arbitrary files.
    extension = os.path.splitext(path)[1]
    logging.warning(extension)
    filename = f"{uuid.uuid4()}{extension}"
    shutil.copyfile(path, os.path.join("static", filename))
    return filename


@app.post("/verify")
def verify_image(fileUpload: UploadFile):
    """Check an uploaded image with ``scripts/verify.sh`` and report the results.

    The upload is written, under its own basename, into a private temporary
    directory and that path is passed to ``verify.sh``. The first three output
    lines are parsed as ``label: value`` pairs: C2PA presence, watermark
    presence, and the path of the original media (or ``n/a``). When an
    original is reported it is copied into ``static/`` under a random name.

    Returns:
        dict with ``response`` (the uploaded filename), ``contains_c2pa``,
        ``contains_watermark`` and ``original_media`` (a ``static/`` filename,
        or ``"n/a"``).

    Raises:
        HTTPException: 400 if the upload has no usable filename; 500 if the
            verify.sh output does not have the expected shape.
        subprocess.CalledProcessError: if verify.sh exits non-zero.
    """
    logging.warning("in verify")
    logging.warning(fileUpload.filename)

    # Drop any client-supplied directory components.
    fn = os.path.basename(fileUpload.filename or "")
    if fn in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="uploaded file needs a filename")

    # A private directory stops uploads from clobbering each other or files in
    # the server's working directory, and is removed when the block exits.
    with tempfile.TemporaryDirectory() as workdir:
        upload_path = os.path.join(workdir, fn)
        with open(upload_path, "wb") as out:
            shutil.copyfileobj(fileUpload.file, out)

        # Pass the path that was actually written (not the raw client name).
        output = subprocess.check_output(
            ["./scripts/verify.sh", upload_path], encoding="utf-8", errors="replace"
        )
        logging.warning(output)

        c2pa, watermark, original_media = _parse_verify_output(output)

        # Copy while the temp dir still exists, in case the reported original
        # is the uploaded file itself.
        if original_media != "n/a":
            original_media = _publish_original_media(original_media)

    return {
        "response": fileUpload.filename,
        "contains_c2pa": c2pa,
        "contains_watermark": watermark,
        "original_media": original_media,
    }
|
|
|
|
|
|
|
# Directory served at "/" and holding index.html (relative to the working
# directory, matching the "static/" paths the endpoints above write to).
STATIC_DIR = "static"


@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page, ``static/index.html``."""
    return FileResponse(
        path=os.path.join(STATIC_DIR, "index.html"), media_type="text/html"
    )


# Mounted LAST: routes are matched in registration order and a mount at "/"
# matches every path, so anything registered after it would be unreachable.
app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")
|
|
|
|
|
|
|
|