|
import json
import logging
import os
import shutil
import subprocess
import uuid

import torch
from diffusers import (
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
    StableDiffusionPipeline,
)
from fastapi import FastAPI, HTTPException, UploadFile
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
|
|
|
# The ASGI application object; every route and the static mount in this
# module are registered on it.
app: FastAPI = FastAPI()
|
|
|
|
|
@app.get("/generate")
def generate_image(prompt: str, model: str):
    """Generate an image from *prompt*, sign it with truepic, and publish it.

    The pipeline is loaded for every request, switched to an Euler discrete
    scheduler, moved to CUDA, and run for 50 steps at 512x512. The result is
    saved as a JPEG and passed to ``./truepic sign`` with an inline assertion
    recording the model name, model version and prompt. The signed copy is
    written to ``<cwd>/static/<filename>``.

    Args:
        prompt: Text prompt for the diffusion pipeline.
        model: ``"<model_name>,<model_version>"``. The name is passed to
            ``StableDiffusionPipeline.from_pretrained``; the version is only
            recorded in the signed assertion. Fields after the second comma
            are ignored.

    Returns:
        ``{"response": filename}``: the signed image's name under ``static/``.

    Raises:
        HTTPException: 400 if *model* lacks a non-empty name and a version.
        subprocess.CalledProcessError: If ``./truepic sign`` exits non-zero.
    """
    model_parts = model.split(",")
    # Without this check a value with no comma raised IndexError (HTTP 500).
    if len(model_parts) < 2 or not model_parts[0]:
        raise HTTPException(
            status_code=400,
            detail="model must be '<model_name>,<model_version>'",
        )
    model_name, model_version = model_parts[0], model_parts[1]

    # Return unused cached CUDA memory before loading another model.
    torch.cuda.empty_cache()

    # NOTE(review): model_name comes straight from the query string, so any
    # hub repository or local path can be loaded; consider an allow-list.
    pipeline = StableDiffusionPipeline.from_pretrained(
        model_name, torch_dtype=torch.float16
    )
    pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to("cuda")

    image = pipeline(prompt, num_inference_steps=50, height=512, width=512).images[0]

    filename = f"{uuid.uuid4()}.jpg"
    assertion = {
        "assertions": [
            {
                "label": "com.truepic.custom.ai",
                "data": {
                    "model_name": model_name,
                    "model_version": model_version,
                    "prompt": prompt,
                },
            }
        ]
    }

    try:
        image.save(filename)
        subprocess.check_output(
            [
                "./truepic",
                "sign",
                filename,
                "--assertions-inline",
                json.dumps(assertion),
                "--output",
                os.path.join(os.getcwd(), "static", filename),
            ]
        )
    finally:
        # Only the signed copy under static/ is returned to the client; remove
        # the unsigned intermediate so requests don't accumulate files in cwd.
        try:
            os.remove(filename)
        except FileNotFoundError:
            pass

    return {"response": filename}
|
|
|
@app.post("/verify")
def verify_image(fileUpload: UploadFile):
    """Store an uploaded file in the working directory and run ``sign.sh`` on it.

    NOTE(review): despite the route name, this invokes ``./scripts/sign.sh``;
    confirm whether a verification script was intended.

    Args:
        fileUpload: Multipart upload. Only the basename of the client-supplied
            filename is used, so path components cannot escape the working
            directory.

    Returns:
        ``{"response": name}``: the name the file was saved under.

    Raises:
        HTTPException: 400 if the upload has no usable filename.
        subprocess.CalledProcessError: If ``sign.sh`` exits non-zero.
    """
    logging.warning("in verify")
    logging.warning(fileUpload.filename)

    fn = os.path.basename(fileUpload.filename or "")
    # Previously the script still ran (with None or "") when no name was sent.
    if fn in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="upload must have a file name")

    # NOTE(review): uploads land in the process working directory, so a file
    # named like an application file (e.g. "truepic") overwrites it. Consider
    # a dedicated upload directory.
    with open(fn, "wb") as out:
        shutil.copyfileobj(fileUpload.file, out)

    # Pass the sanitized name: it is the file that was actually written.
    subprocess.check_output(["./scripts/sign.sh", fn])

    return {"response": fn}
|
|
|
@app.post("/sign")
def sign_image(fileUpload: UploadFile):
    """Store an uploaded file in the working directory under its basename.

    NOTE(review): despite the route name, no signing happens here; the file
    is only saved. Confirm whether a call to the signing tool is missing.

    Args:
        fileUpload: Multipart upload. Only the basename of the client-supplied
            filename is used, so path components cannot escape the working
            directory.

    Returns:
        ``{"response": name}``: the name the file was saved under.

    Raises:
        HTTPException: 400 if the upload has no usable filename.
    """
    # Was a copy-pasted "in verify", which mislabelled these log entries.
    logging.warning("in sign")
    logging.warning(fileUpload.filename)

    fn = os.path.basename(fileUpload.filename or "")
    if fn in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="upload must have a file name")

    # NOTE(review): uploads land in the process working directory and can
    # overwrite application files of the same name.
    with open(fn, "wb") as out:
        shutil.copyfileobj(fileUpload.file, out)

    return {"response": fn}
|
|
|
|
|
@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page ``static/index.html``.

    This route must be registered before the catch-all static mount below.
    Starlette matches routes in registration order, and a mount at ``/``
    matches every path, so a route added after it is never reached. The
    path is relative to the working directory, like the mount's ``static``
    directory and the ``static/`` output of ``/generate``.
    """
    return FileResponse(
        path=os.path.join("static", "index.html"), media_type="text/html"
    )


# Serve everything under ``static/`` (including images signed by /generate).
# Keep this as the last registration on ``app``; see ``index`` above.
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
|