Spaces: the demo Spaces embedded here run on a T4 GPU.
from fastapi import FastAPI, UploadFile
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
import subprocess
import os
import json
import uuid
import logging
import torch
from diffusers import (
    StableDiffusionPipeline,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
)

app = FastAPI()

def file_extension(filename):
    # return the extension including its leading dot, e.g. ".jpg"
    return os.path.splitext(filename)[1]

@app.post("/generate")  # route path assumed for illustration
def generate_image(prompt, model):
    torch.cuda.empty_cache()

    # the model is passed as "name,version"
    model_name, model_version = model.split(",")

    pipeline = StableDiffusionPipeline.from_pretrained(
        model_name, torch_dtype=torch.float16
    )
    pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to("cuda")

    image = pipeline(prompt, num_inference_steps=50, height=512, width=512).images[0]
    filename = str(uuid.uuid4()) + ".jpg"
    image.save(filename)

    # C2PA assertion recording the model and prompt used to generate the image
    assertion = {
        "assertions": [
            {
                "label": "com.truepic.custom.ai",
                "data": {
                    "model_name": model_name,
                    "model_version": model_version,
                    "prompt": prompt,
                },
            }
        ]
    }
    json_object = json.dumps(assertion)

    # sign the image and embed the assertion, then publish the signed copy
    subprocess.check_output(
        ["./scripts/sign.sh", filename, "--assertions-inline", json_object]
    )
    subprocess.check_output(["cp", "output.jpg", "static/" + filename])

    return {"response": filename}

@app.post("/verify")  # route path assumed for illustration
def verify_image(fileUpload: UploadFile):
    logging.warning("in verify")
    logging.warning(fileUpload.filename)

    # check if the file has been uploaded
    if fileUpload.filename:
        # strip the leading path from the file name
        fn = os.path.basename(fileUpload.filename)

        # read the upload and write it to disk on the server
        with open(fn, "wb") as f:
            f.write(fileUpload.file.read())

    # run the verification script against the uploaded file
    response = subprocess.check_output(
        [
            "./scripts/verify.sh",
            fileUpload.filename,
        ]
    )
    logging.warning(response)

    # verify.sh reports the C2PA status, the watermark status, and any
    # recovered original media as one "key: value" line each
    response_list = response.splitlines()

    c2pa_string = str(response_list[0])
    c2pa = c2pa_string.split(":", 1)[1].strip(" ").strip("'")

    watermark_string = str(response_list[1])
    watermark = watermark_string.split(":", 1)[1].strip(" ").strip("'")

    original_media_string = str(response_list[2])
    original_media = original_media_string.split(":", 1)[1].strip(" ").strip("'")

    if original_media != "n/a":
        # copy the recovered original into the static directory under a fresh name
        original_media_extension = file_extension(original_media)
        logging.warning(original_media_extension)

        filename = str(uuid.uuid4()) + original_media_extension
        subprocess.check_output(["cp", original_media, "static/" + filename])
        original_media = filename

    return {
        "response": fileUpload.filename,
        "contains_c2pa": c2pa,
        "contains_watermark": watermark,
        "original_media": original_media,
    }

app.mount("/", StaticFiles(directory="static", html=True), name="static")


@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="/app/static/index.html", media_type="text/html")
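
To try the app outside of Spaces, it can be served with uvicorn, FastAPI's usual ASGI server. A minimal launch sketch, assuming the code above lives in app.py next to the static/ directory and the scripts/ helpers; port 7860 follows the Spaces convention and is otherwise arbitrary:

import uvicorn

if __name__ == "__main__":
    # serve the API endpoints and the static front end on one port
    uvicorn.run(app, host="0.0.0.0", port=7860)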