File size: 2,417 Bytes
342d3dd
 
 
 
 
 
 
 
 
bacf369
342d3dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bacf369
 
 
342d3dd
 
bacf369
342d3dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
from typing import List

from fastapi import APIRouter

from src.libs.image import save_img, generate_img_index
from src.libs.model import CganCols, get_model
from src.libs.s3 import s3client
from src.models.generate import GenerateResult, ImageResult
from src.models.main import User, Method, UserIndex

# Local scratch directory: each generated PNG is written here, uploaded to S3,
# and then deleted (see s3uploadimage). "./src/store" is resolved against the
# process's working directory at import time.
IMAGE_STORE_PATH = os.path.abspath("./src/store")
# S3 bucket that receives the generated images.
BUCKET_NAME = "pimthaigans-image-container"

# Make sure IMAGE_STORE_PATH exists. exist_ok=True replaces the former
# exists()-then-makedirs() check, which raced: if another worker created the
# directory in between, this import failed with FileExistsError.
os.makedirs(IMAGE_STORE_PATH, exist_ok=True)

# Every route in this module is mounted under /generate and grouped under the
# "Generate" tag; all of them advertise a 404 response in the API schema.
router = APIRouter(
    tags=["Generate"],
    prefix="/generate",
    responses={404: {"description": "Not found"}},
)

# Built once at import time and shared by every request handled by this worker.
# NOTE(review): CganCols is opaque here; it presumably loads the generator
# models exposed via `model_cols` (indexed in s3uploadimage) — confirm.
model = CganCols()


@router.get("/")
async def info():
    """Return a short, static description of the generate endpoint."""
    message = "This is the generate endpoint"
    return {"info": message}


@router.get("/status")
async def status():
    """Health-check route: unconditionally reports that the service is OK."""
    return dict(status="OK")


@router.post("/")
async def generate(user: UserIndex) -> GenerateResult:
    """Generate images for the requesting user.

    When ``user.method`` is ``Method.index`` only the image at ``user.index``
    is produced; for any other method the full set is produced.

    Args:
        user: Request body carrying the target ``user``, the ``method`` and,
            for single-image requests, the ``index``.

    Returns:
        The ``GenerateResult`` built by ``generate_index`` or ``generate_all``.
    """
    wants_single = user.method == Method.index
    if not wants_single:
        return await generate_all(user.user)
    return await generate_index(user.user, user.index)


async def generate_index(user: User, index: int) -> GenerateResult:
    """Generate and upload the single image at ``index`` for ``user``.

    Args:
        user: Requesting user; its ``uuid`` names the local file and the S3 key
            prefix (see ``s3uploadimage``).
        index: Image index passed straight to ``s3uploadimage``. ``generate_all``
            uses indices 0..87; nothing here validates the range.

    Returns:
        A ``GenerateResult`` with ``method=Method.index`` and a one-element
        ``result`` list.
    """
    s3 = s3client()
    try:
        # NOTE(review): s3uploadimage is synchronous (model inference, file
        # write, S3 upload), so this coroutine blocks the event loop while it
        # runs; consider offloading it to a thread pool.
        result: List[ImageResult] = [s3uploadimage(user, s3, index)]
    finally:
        # Release the client even when generation/upload raises; previously an
        # exception skipped close() and leaked the client.
        s3.close()

    return GenerateResult(user=user, method=Method.index, result=result)


async def generate_all(user: User) -> GenerateResult:
    """Generate and upload every image (indices 0..87) for ``user``.

    A single S3 client is shared across all uploads and is always closed, even
    if one of the uploads fails part-way through.

    Args:
        user: Requesting user; its ``uuid`` names the local files and the S3 key
            prefix (see ``s3uploadimage``).

    Returns:
        A ``GenerateResult`` with ``method=Method.all`` and one ``ImageResult``
        per index, in index order.
    """
    # NOTE(review): 88 presumably equals (number of models) * 11, matching the
    # `index % 11` split in s3uploadimage — confirm against get_model.
    image_count = 88

    s3 = s3client()
    try:
        # NOTE(review): each call is synchronous, so this coroutine blocks the
        # event loop for the duration of all 88 generations/uploads.
        result: List[ImageResult] = [
            s3uploadimage(user, s3, index) for index in range(image_count)
        ]
    finally:
        # Release the client even when an upload raises.
        s3.close()

    return GenerateResult(user=user, method=Method.all, result=result)


def s3uploadimage(user, s3, index):
    """Generate the image at ``index`` for ``user``, upload it to S3, return its URL.

    The image is written as a PNG under ``IMAGE_STORE_PATH``, uploaded to
    ``BUCKET_NAME`` at key ``<user.uuid>/<NN>.png`` (index zero-padded to two
    digits), and the local file is deleted afterwards, whether or not the
    upload succeeded.

    Args:
        user: Object with a ``uuid`` attribute, used in the file name and S3 key.
        s3: S3 client; called as ``upload_file(local_path, bucket, key)``.
        index: Image index. It selects the model via ``get_model(index)`` and is
            reduced with ``index % 11`` before being passed to the generator.

    Returns:
        ``ImageResult`` with ``index`` and the public object URL.
    """
    padded_index = str(index).zfill(2)
    # NOTE(review): the local path depends only on uuid and index, so two
    # concurrent requests for the same user/index would share this file.
    output_path = os.path.join(IMAGE_STORE_PATH, f"{user.uuid}-{padded_index}.png")
    s3_path: str = f"{user.uuid}/{padded_index}.png"

    used_model = model.model_cols[get_model(index)]
    # NOTE(review): `% 11` suggests each model covers 11 indices — confirm.
    image = generate_img_index(reloaded_model=used_model, index=index % 11)

    try:
        save_img(image, output_path)
        s3.upload_file(output_path, BUCKET_NAME, s3_path)
    finally:
        # Always remove the local copy so failed requests don't accumulate files
        # in IMAGE_STORE_PATH. The file may not exist if save_img raised first.
        try:
            os.remove(output_path)
        except FileNotFoundError:
            pass

    image_url = f'https://{BUCKET_NAME}.s3.amazonaws.com/{s3_path}'
    return ImageResult(index=index, image_url=image_url)