# Page residue from the hosting site (not code): uploaded by jclyo1,
# commit da55d71 "updates", raw/history/blame view, 4.1 kB.
import json
import logging
import os
import shutil
import subprocess
import uuid

import torch
from diffusers import (
    StableDiffusionPipeline,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
)
from fastapi import FastAPI, HTTPException, UploadFile
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
# ASGI application instance; every route and the static mount in this
# module are registered on it.
app: FastAPI = FastAPI()
def file_extension(filename: str) -> str:
    """Return the extension of *filename* without its leading dot.

    Only the final suffix of the last path component counts, so
    ``"photo.v2.jpg"`` yields ``"jpg"`` and ``"./uploads/a.png"`` yields
    ``"png"``. A name without an extension yields ``""`` instead of raising.
    """
    # splitext looks only at the basename and the LAST dot. A plain
    # split(".")[1] would pick up dots in directory names or inner suffixes.
    return os.path.splitext(filename)[1][1:]
@app.get("/generate")
def generate_image(prompt, model):
    """Generate an image from *prompt*, sign it, and publish it under static/.

    Args:
        prompt: Text prompt for the diffusion pipeline; it is also recorded in
            the inline assertion passed to ``scripts/sign.sh``.
        model: ``"<model name>,<model version>"``. The name is passed to
            ``StableDiffusionPipeline.from_pretrained``; the version is only
            recorded in the assertion. Anything after a second comma is
            ignored.

    Returns:
        ``{"response": filename}`` where *filename* is the random ``.jpg``
        name the signed image was copied to under ``static/``.

    Raises:
        HTTPException: 400 when *model* has no comma-separated version.
        subprocess.CalledProcessError: if ``scripts/sign.sh`` exits non-zero.
    """
    # Validate the request before touching the GPU so malformed input fails
    # fast with a client error instead of an IndexError (HTTP 500).
    model_parts = model.split(",")
    if len(model_parts) < 2:
        raise HTTPException(
            status_code=400, detail="model must be '<name>,<version>'"
        )
    model_name, model_version = model_parts[0], model_parts[1]

    torch.cuda.empty_cache()
    # The pipeline is loaded from scratch on every request (fp16, Euler scheduler).
    pipeline = StableDiffusionPipeline.from_pretrained(
        model_name, torch_dtype=torch.float16
    )
    pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to("cuda")
    image = pipeline(prompt, num_inference_steps=50, height=512, width=512).images[0]

    filename = f"{uuid.uuid4()}.jpg"
    image.save(filename)

    assertion = {
        "assertions": [
            {
                "label": "com.truepic.custom.ai",
                "data": {
                    "model_name": model_name,
                    "model_version": model_version,
                    "prompt": prompt,
                },
            }
        ]
    }
    # argv is passed as a list (no shell), so the user-controlled prompt
    # inside the JSON cannot be interpreted by a shell.
    subprocess.check_output(
        ["./scripts/sign.sh", filename, "--assertions-inline", json.dumps(assertion)]
    )
    # The signed image is read back from ./output.jpg, which sign.sh is
    # presumably expected to write (TODO confirm). Because that name is fixed,
    # concurrent requests could overwrite each other's result.
    shutil.copyfile("output.jpg", os.path.join("static", filename))
    return {"response": filename}
def _verify_value(line: str) -> str:
    """Return the value of a ``<label>: <value>`` line printed by verify.sh.

    The text after the first colon is returned with surrounding whitespace
    and single quotes trimmed; a line without a colon yields ``""``.
    """
    return line.partition(":")[2].strip().strip("'")


def _publish_to_static(src: str) -> str:
    """Copy *src* into ``static/`` under a random UUID name; return that name.

    The copy keeps *src*'s extension as reported by ``file_extension``.
    """
    published = str(uuid.uuid4()) + "." + file_extension(src)
    shutil.copyfile(src, os.path.join("static", published))
    return published


@app.post("/verify")
def verify_image(fileUpload: UploadFile):
    """Verify an uploaded image with ``scripts/verify.sh`` and publish a copy.

    The upload is saved under its base name in the working directory and
    passed to ``./scripts/verify.sh``. The first three stdout lines are read
    as ``<label>: <value>`` pairs, in order: C2PA presence, watermark
    presence, and an original-media path (``n/a`` when there is none).

    If the C2PA value is ``"true"``, the upload itself is copied into
    ``static/``. Otherwise, if an original-media path was reported, that file
    is copied instead. Copies get a random UUID name with the source's
    extension.

    Returns:
        A dict with the client-supplied file name (``response``), the
        ``contains_c2pa`` and ``contains_watermark`` values as printed by the
        script, and ``result_media`` (the published name under ``static/``,
        or ``"n/a"``).

    Raises:
        HTTPException: 400 if the upload has no usable file name; 500 if
            verify.sh prints fewer than three lines.
        subprocess.CalledProcessError: if verify.sh exits non-zero.
    """
    logging.warning("in verify")
    logging.warning(fileUpload.filename)

    # Keep only the base name so a crafted name such as "../../x.jpg" cannot
    # write outside the working directory. Every later step uses `fn`, never
    # the raw client-supplied name, so the script verifies (and we publish)
    # exactly the file that was saved.
    fn = os.path.basename(fileUpload.filename or "")
    if fn in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="a named file upload is required")
    with open(fn, "wb") as saved:
        shutil.copyfileobj(fileUpload.file, saved)

    output = subprocess.check_output(["./scripts/verify.sh", fn])
    logging.warning(output)

    # Decode rather than str(bytes): str() produces "b'...'" reprs with escaped
    # non-ASCII characters. surrogateescape keeps undecodable bytes in a
    # reported path usable as a filesystem path.
    lines = output.decode("utf-8", errors="surrogateescape").splitlines()
    if len(lines) < 3:
        raise HTTPException(status_code=500, detail="unexpected output from verify.sh")
    c2pa, watermark, original_media = (_verify_value(line) for line in lines[:3])

    if c2pa == "true":
        result_media = _publish_to_static(fn)
    elif original_media not in ("", "n/a"):
        result_media = _publish_to_static(original_media)
    else:
        result_media = "n/a"

    return {
        "response": fileUpload.filename,
        "contains_c2pa": c2pa,
        "contains_watermark": watermark,
        "result_media": result_media,
    }
@app.get("/")
def index() -> FileResponse:
    """Serve the front-end page ``static/index.html`` at ``/``."""
    # Same directory the static mount below serves, relative to the working
    # directory, rather than a hard-coded absolute "/app/..." path.
    return FileResponse(path=os.path.join("static", "index.html"), media_type="text/html")


# Register the static mount LAST. A mount at "/" matches every path, so any
# route declared after it would never be reached; `index` above would
# otherwise be dead code. html=True also serves index.html for directories.
app.mount("/", StaticFiles(directory="static", html=True), name="static")