|
from dataclasses import dataclass, field |
|
import os |
|
import json |
|
import logging |
|
from argparse import Namespace |
|
from typing import List, Literal, Optional, Union |
|
from pydantic import AnyHttpUrl, BaseSettings, HttpUrl, validator, BaseModel |
|
|
|
|
|
# Absolute path of the directory that contains this module; used below as the
# base directory for the relative JSON config path and the default log file.
CURRENT_DIR_PATH = os.path.abspath(os.path.dirname(__file__))
|
|
|
|
|
|
|
class RequestDataStructure(BaseModel):
    """Pydantic model describing an incoming request payload.

    Presumably this is the body of the configured API route -- confirm
    against the route handler. Every field except ``input_text`` is optional
    and defaults to ``None``; ``input_text`` defaults to ``[""]``.
    """

    # Texts to process; defaults to a one-element list holding "".
    input_text: List[str] = [""]

    # Optional request identifier.
    uuid: Optional[int] = None

    # Optional image input; its encoding (path, URL, base64, ...) is not
    # visible here -- TODO confirm with the pipeline that consumes it.
    input_image: Optional[str] = None

    # Optional generation/guidance knobs forwarded to the model pipeline.
    skip_steps: Optional[int] = None
    clip_guidance_scale: Optional[int] = None
    init_scale: Optional[int] = None
|
|
|
|
|
@dataclass
class APIConfig:
    """Runtime configuration for the API server.

    An instance starts out with the defaults declared below and is normally
    populated from a JSON file by :meth:`setup_config`.
    """

    # --- Server ------------------------------------------------------------
    # Bind address: a bare host/IP such as "127.0.0.1", not a URL.
    SERVER_HOST: str = "127.0.0.1"
    SERVER_PORT: int = 8990
    # Also used as the stem of the default log file name by setup_logger.
    SERVER_NAME: str = ""
    PROJECT_NAME: str = ""
    API_PREFIX_STR: str = "/api"

    # --- Endpoint ----------------------------------------------------------
    API_method: Literal[
        "POST", "GET", "PUT", "OPTIONS", "WEBSOCKET",
        "PATCH", "DELETE", "TRACE", "CONNECT",
    ] = "POST"
    API_path: str = "/TextClassification"
    API_tags: List[str] = field(default_factory=lambda: [""])

    # --- CORS --------------------------------------------------------------
    # Plain strings: the default contains the wildcard "*", which is not a URL.
    BACKEND_CORS_ORIGINS: List[str] = field(default_factory=lambda: ["*"])
    allow_credentials: bool = True
    allow_methods: List[str] = field(default_factory=lambda: ["*"])
    allow_headers: List[str] = field(default_factory=lambda: ["*"])

    # --- Logging -----------------------------------------------------------
    # "" -> default file next to this module; see setup_logger for the rules.
    log_file_path: str = ""
    log_level: str = "INFO"

    # --- Pipeline ----------------------------------------------------------
    pipeline_type: str = ""
    model_name: str = ""
    # Model-specific options from the PIPELINE section; empty until
    # setup_config runs, so it is always safe to read.
    model_settings: dict = field(default_factory=dict)

    def setup_config(self, args: Namespace) -> None:
        """Overwrite this config with values read from a JSON file.

        Args:
            args: parsed CLI arguments; ``args.config_path`` is resolved
                relative to this module's directory (an absolute path is
                used as-is).

        The file must contain ``SERVER``, ``LOGGING`` and ``PIPELINE``
        objects holding every key read below.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
            KeyError: if a required section or key is missing.
        """
        # os.path.join anchors relative paths at this module's directory
        # while letting an absolute config_path through unchanged.
        config_path = os.path.join(CURRENT_DIR_PATH, args.config_path)
        with open(config_path, "r", encoding="utf-8") as jsonfile:
            config = json.load(jsonfile)

        server_config = config["SERVER"]
        logging_config = config["LOGGING"]
        pipeline_config = config["PIPELINE"]

        # Server
        self.SERVER_HOST = server_config["SERVER_HOST"]
        self.SERVER_PORT = server_config["SERVER_PORT"]
        self.SERVER_NAME = server_config["SERVER_NAME"]
        self.PROJECT_NAME = server_config["PROJECT_NAME"]
        self.API_PREFIX_STR = server_config["API_PREFIX_STR"]

        # Endpoint
        self.API_method = server_config["API_method"]
        self.API_path = server_config["API_path"]
        self.API_tags = server_config["API_tags"]

        # CORS
        self.BACKEND_CORS_ORIGINS = server_config["BACKEND_CORS_ORIGINS"]
        self.allow_credentials = server_config["allow_credentials"]
        self.allow_methods = server_config["allow_methods"]
        self.allow_headers = server_config["allow_headers"]

        # Logging
        self.log_file_path = logging_config["log_file_path"]
        self.log_level = logging_config["log_level"]

        # Pipeline
        self.pipeline_type = pipeline_config["pipeline_type"]
        self.model_name = pipeline_config["model_name"]
        self.model_settings = pipeline_config["model_settings"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _resolve_log_level(raw_level) -> int:
    """Map a configured log level (name or number) to a numeric logging level.

    Names are matched case-insensitively ("info", "INFO", "Warn"); integers
    are passed through; anything unrecognised falls back to ``logging.INFO``.
    """
    if isinstance(raw_level, int) and not isinstance(raw_level, bool):
        return raw_level
    level = logging.getLevelName(str(raw_level).strip().upper())
    # getLevelName returns the string "Level <name>" for unknown names.
    return level if isinstance(level, int) else logging.INFO


def setup_logger(logger, user_config: "APIConfig"):
    """Attach a console handler and a file handler to *logger*.

    Args:
        logger: the ``logging.Logger`` to configure.
        user_config: object exposing ``log_level``, ``log_file_path`` and
            ``SERVER_NAME`` (normally an ``APIConfig``).

    The log file is chosen from ``user_config.log_file_path``:

    * empty        -> ``<CURRENT_DIR_PATH>/<SERVER_NAME>.log``
    * ends ".log"  -> used verbatim as the log file
    * otherwise    -> treated as a directory: ``<path>/<SERVER_NAME>.log``

    Handlers are appended on every call, so calling this twice on the same
    logger duplicates output.

    Returns:
        The same ``logger`` object.
    """
    logger.setLevel(_resolve_log_level(user_config.log_level))

    console_handler = logging.StreamHandler()

    log_path = user_config.log_file_path
    default_name = f"{user_config.SERVER_NAME}.log"
    if not log_path:
        filename = os.path.join(CURRENT_DIR_PATH, default_name)
    elif log_path.lower().endswith(".log"):
        filename = log_path
    else:
        filename = os.path.join(log_path, default_name)
    file_handler = logging.FileHandler(filename=filename, encoding="utf-8")

    formatter = logging.Formatter(
        "%(asctime)s - %(module)s - %(funcName)s - line:%(lineno)d - %(levelname)s - %(message)s"
    )
    console_handler.setFormatter(formatter)
    file_handler.setFormatter(formatter)

    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

    return logger
|
|
|
# Root logger: logging.getLogger() called without a name returns the root.
api_logger = logging.getLogger()

# Module-wide configuration object; holds APIConfig's declared defaults until
# setup_config() loads values from a JSON file.
user_config = APIConfig()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|