import functools
import traceback
from typing import Dict, List, Optional

from requests.adapters import Retry, HTTPAdapter
import requests
from pydash import includes

from internals.data.task import Task
from internals.util.config import api_endpoint, api_headers
from internals.util.slack import Slack


class RetryRequest:
    """Factory for a ``requests.Session`` that retries transient failures.

    Calling ``RetryRequest()`` does NOT return a ``RetryRequest`` instance:
    ``__new__`` builds and returns a configured ``requests.Session`` so it can
    be used directly as ``with RetryRequest() as session: ...``.

    The session retries up to 5 times with exponential backoff
    (``backoff_factor=2``) on connection failures and on HTTP
    500/502/503/504 responses, for both ``http://`` and ``https://`` URLs.
    """

    # urllib3's default allowed_methods does not include PATCH, so without
    # this the status_forcelist retries never applied to the PATCH calls in
    # this module. Those calls only set a state value, so retrying is safe.
    _ALLOWED_METHODS = Retry.DEFAULT_ALLOWED_METHODS | {"PATCH"}

    def __new__(cls):
        retry = Retry(
            total=5,
            backoff_factor=2,
            status_forcelist=[500, 502, 503, 504],
            allowed_methods=cls._ALLOWED_METHODS,
        )
        adapter = HTTPAdapter(max_retries=retry)
        session = requests.Session()
        session.mount("https://", adapter)
        session.mount("http://", adapter)
        return session


def updateSource(sourceId, userId, state):
    """Set the ``state`` of source ``sourceId`` through the autodraft API.

    Best-effort: timeouts and request errors (including non-2xx responses)
    are printed and swallowed; nothing is raised to the caller.

    Args:
        sourceId: Source identifier, interpolated into the URL path.
        userId: Sent as the ``user-id`` header.
        state: New state string (this module uses "INPROGRESS",
            "COMPLETED" and "FAILED").
    """
    url = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}"
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }
    data = {"state": state}
    try:
        with RetryRequest() as session:
            response = session.patch(url, headers=headers, json=data, timeout=10)
            # Error responses that are not retried (e.g. 4xx) were previously
            # ignored silently; raise so the handler below reports them.
            response.raise_for_status()
    except requests.exceptions.Timeout:
        print("Request timed out while updating source")
    except requests.exceptions.RequestException as e:
        print(f"Error while updating source: {e}")


def saveGeneratedImages(sourceId, userId, has_nsfw: bool):
    """Mark source ``sourceId`` ACTIVE and record its ``has_nsfw`` flag.

    Sends ``{"state": "ACTIVE", "has_nsfw": has_nsfw}`` to the source's
    ``/generatedImages`` endpoint. Best-effort: timeouts and request errors
    (including non-2xx responses) are printed and swallowed.

    Args:
        sourceId: Source identifier, interpolated into the URL path.
        userId: Sent as the ``user-id`` header.
        has_nsfw: Whether the generated images were flagged as NSFW.
    """
    url = (
        api_endpoint()
        + "/autodraft-crecoai/source/"
        + str(sourceId)
        + "/generatedImages"
    )
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }
    data = {"state": "ACTIVE", "has_nsfw": has_nsfw}
    try:
        with RetryRequest() as session:
            # A timeout is required; without one this call could block forever.
            response = session.patch(url, headers=headers, json=data, timeout=10)
            response.raise_for_status()
    except requests.exceptions.Timeout:
        print("Request timed out while saving image")
    except requests.exceptions.RequestException as e:
        print("Failed to mark source as active: ", e)


def getStyles() -> Optional[Dict]:
    """Fetch the style catalogue from the autodraft API.

    Returns:
        The decoded JSON body, or None if the request timed out.

    Raises:
        requests.exceptions.RequestException: re-raised (after printing) for
            any non-timeout request failure.
    """
    url = api_endpoint() + "/autodraft-crecoai/style"
    try:
        with RetryRequest() as session:
            response = session.get(
                url,
                timeout=10,
                # NOTE(review): hard-coded API key committed to source control.
                # Move it to config/secret storage and rotate it. api_headers()
                # is spread last, so an "x-api-key" it supplies overrides this.
                headers={"x-api-key": "kGyEMp)oHB(zf^E5>-{o]I%go", **api_headers()},
            )
            return response.json()
    except requests.exceptions.Timeout:
        print("Request timed out while fetching styles")
    except requests.exceptions.RequestException as e:
        print(f"Error while fetching styles: {e}")
        raise
    return None


def getCharacters(model_id: str) -> Optional[List]:
    """Fetch the characters attached to model ``model_id``.

    Returns:
        ``body["data"]["characters"]`` from the model endpoint, or None on a
        timeout or any other error (request failure, bad JSON, missing key).
    """
    url = api_endpoint() + "/autodraft-crecoai/model/{}".format(model_id)
    try:
        with RetryRequest() as session:
            body = session.get(url, timeout=10, headers=api_headers()).json()
        return body["data"]["characters"]
    except requests.exceptions.Timeout:
        print("Request timed out while fetching characters")
    except Exception as e:
        print(f"Error while fetching characters: {e}")
    return None


def update_db_source_failed(sourceId, userId):
    """Mark source ``sourceId`` as FAILED (best-effort, see updateSource)."""
    updateSource(sourceId, userId, "FAILED")


def update_db(func):
    """Decorator that tracks a Task's source state around ``func``.

    The wrapped function must receive a ``Task`` (or subclass) as its first
    positional argument. The source is marked INPROGRESS before ``func`` runs.

    On success, the source is marked COMPLETED, its generated images are saved
    with the ``has_nsfw`` flag from ``func``'s result (default False), and
    that result is returned. The result is read with ``.get``, so ``func``
    presumably returns a dict; a None result raises AttributeError and is
    handled as a failure.

    On any exception, the error and traceback are printed, a Slack alert is
    sent, the source is marked FAILED, and None is returned. The exception is
    not re-raised.

    Raises:
        Exception: if the first positional argument is missing or is not a Task.
    """

    @functools.wraps(func)
    def caller(*args, **kwargs):
        if not args or not isinstance(args[0], Task):
            raise Exception("First argument must be a Task object")
        task = args[0]
        try:
            updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
            rargs = func(*args, **kwargs)
            has_nsfw = rargs.get("has_nsfw", False)
            updateSource(task.get_sourceId(), task.get_userId(), "COMPLETED")
            saveGeneratedImages(task.get_sourceId(), task.get_userId(), has_nsfw)
            return rargs
        except Exception as e:
            print("Error processing image: {}".format(str(e)))
            traceback.print_exc()
            try:
                Slack().error_alert(task, e)
            except Exception as alert_error:
                # A failing alert must not leave the source stuck INPROGRESS.
                print(f"Failed to send Slack alert: {alert_error}")
            update_db_source_failed(task.get_sourceId(), task.get_userId())

    return caller