# PokeGen / app.py
# Last commit 42422ba by Ron Au:
#   feat(endpoint): Change `create/task` from GET to POST
import threading
from statistics import mean
from time import time

from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from modules.details import rand_details
from modules.inference import generate_image
# Application instance with the interactive API docs (Swagger UI and ReDoc)
# switched off; files under ./static are served from the /static URL prefix.
app = FastAPI(redoc_url=None, docs_url=None)
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")

# Process-wide, in-memory task registry keyed by task_id. It is a plain
# module-level dict, so it is shared by every request and lost on restart.
tasks = dict()
class NewTask(BaseModel):
    """Request body accepted by the task-creation endpoint."""

    # Explicit annotation: pydantic v2 raises PydanticUserError for
    # non-annotated class attributes, while v1 would have inferred `str` from
    # the default anyway. The default prompt is Russian for "pokemon".
    prompt: str = "покемон"
def get_place_in_queue(task_id):
    """Return the 1-based position of *task_id* among unfinished tasks.

    Tasks whose status is "queued" or "processing" are ordered by their
    ``created_at`` timestamp, and the position of *task_id* in that ordering
    is returned. A task that is currently processing therefore still occupies
    a place.

    Returns:
        The 1-based position, or 0 when *task_id* is not queued/processing
        (finished, failed, or unknown).
    """
    active = [
        task
        for task in tasks.values()
        if task["status"] in ("queued", "processing")
    ]
    active.sort(key=lambda task: task["created_at"])
    active_ids = [task["task_id"] for task in active]
    try:
        return active_ids.index(task_id) + 1
    except ValueError:
        # list.index raises ValueError when task_id is not among active tasks.
        return 0
def calculate_eta(task_id, fallback_duration=40):
    """Estimate how many seconds *task_id* will take to finish, to 0.1 s.

    The estimate is the task's ``initial_place_in_queue`` multiplied by the
    mean duration (``completed_at - started_at``) of every task that has a
    ``completed_at`` timestamp. That set includes tasks that failed. The place
    used is the one recorded at creation time, not the current place, so the
    value does not shrink as the queue advances.

    Args:
        task_id: key into ``tasks``; the entry must have
            ``initial_place_in_queue`` set.
        fallback_duration: seconds per queue place to assume before any task
            has finished. Defaults to 40, the previously hard-coded value.

    Returns:
        The estimate rounded to one decimal place.

    Raises:
        KeyError: if *task_id* is not in ``tasks``.
    """
    durations = [
        task["completed_at"] - task["started_at"]
        for task in tasks.values()
        if "completed_at" in task
    ]
    per_task = mean(durations) if durations else fallback_duration
    return round(tasks[task_id]["initial_place_in_queue"] * per_task, 1)
# Guards the "nothing is processing -> claim this task" decision. Starlette
# runs synchronous background tasks in a worker threadpool, so without a lock
# two requests could both observe an idle queue and start generating at once.
_queue_lock = threading.Lock()


def _claim_task(task_id):
    """Atomically mark *task_id* as processing if it may start now.

    Returns True only when *task_id* exists, is still "queued", and no other
    task is "processing". The caller then owns the task and must run it.
    """
    with _queue_lock:
        task = tasks.get(task_id)
        if task is None or task["status"] != "queued":
            # Already taken (e.g. by another worker's drain loop) or unknown.
            return False
        if any(other["status"] == "processing" for other in tasks.values()):
            return False
        task["status"] = "processing"
        task["started_at"] = time()
        return True


def _run_task(task_id):
    """Generate the image for an already-claimed task and record the outcome."""
    task = tasks[task_id]
    try:
        task["value"] = generate_image(task["prompt"])
    except Exception as ex:
        task["status"] = "failed"
        task["error"] = repr(ex)
    else:
        task["status"] = "completed"
    finally:
        task["completed_at"] = time()


def process_task(task_id):
    """Run *task_id* if possible, then keep draining queued tasks one at a time.

    The call does nothing when another task is already processing or when
    *task_id* is no longer queued. In that case the worker currently owning
    the processing slot picks the task up when it drains the queue. After each
    task finishes, the first "queued" task in ``tasks`` iteration order is
    started next. A loop is used instead of recursion so a long backlog does
    not grow the call stack.
    """
    while _claim_task(task_id):
        _run_task(task_id)
        queued_tasks = [task for task in tasks.values() if task["status"] == "queued"]
        if not queued_tasks:
            return
        print(f"Tasks remaining: {len(queued_tasks)}")
        task_id = queued_tasks[0]["task_id"]
@app.head('/')
@app.get('/')
def index():
    """Serve ``static/index.html`` as HTML for GET and HEAD requests to ``/``."""
    page = FileResponse("static/index.html", media_type="text/html")
    return page
@app.get('/details')
def generate_details():
    """Return the result of ``rand_details()`` for this request.

    Presumably a freshly randomised set of details on every call; confirm
    against ``modules.details``.
    """
    details = rand_details()
    return details
@app.post('/task/create')
def create_task(background_tasks: BackgroundTasks, new_task: NewTask):
    """Register a new generation task and schedule it to run in the background.

    The task id combines the creation timestamp and the prompt. The stored
    task dict is returned, including its initial place in the queue.
    """
    now = time()
    prompt = new_task.prompt
    task_id = f"{now}_{prompt}"

    task = dict(
        task_id=task_id,
        created_at=now,
        prompt=prompt,
        status="queued",
        poll_count=0,
    )
    tasks[task_id] = task
    # Must be computed after insertion so the new task is counted.
    task["initial_place_in_queue"] = get_place_in_queue(task_id)

    background_tasks.add_task(process_task, task_id)
    return task
@app.get('/task/poll')
def poll_task(task_id: str):
    """Refresh and return the state of a task for a polling client.

    Recomputes ``place_in_queue`` and ``eta``, increments ``poll_count``, and
    returns the stored task dict.

    Raises:
        HTTPException: 404 when *task_id* was never created. Previously the
            bare KeyError surfaced as a 500 Internal Server Error.
    """
    task = tasks.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")
    task["place_in_queue"] = get_place_in_queue(task_id)
    task["eta"] = calculate_eta(task_id)
    task["poll_count"] += 1
    return task