|
import traceback |
|
from typing import Dict, List, Optional |
|
|
|
from requests.adapters import Retry, HTTPAdapter |
|
import requests |
|
from pydash import includes |
|
|
|
from internals.data.task import Task |
|
from internals.util.config import api_endpoint, api_headers |
|
from internals.util.slack import Slack |
|
|
|
|
|
class RetryRequest:
    """Factory for a ``requests.Session`` with automatic retries.

    ``RetryRequest()`` does not produce a ``RetryRequest`` instance: ``__new__``
    builds and returns a configured ``requests.Session`` directly, so callers
    use it as ``with RetryRequest() as session: ...``.

    The session retries up to 5 times with exponential backoff
    (``backoff_factor=2``), including on HTTP 500/502/503/504 responses.
    """

    def __new__(cls):
        retry_options = {
            "total": 5,
            "backoff_factor": 2,
            "status_forcelist": [500, 502, 503, 504],
        }
        # urllib3 only applies status-based retries to methods in its
        # allow-list, and the default list omits PATCH, so PATCH calls made
        # through this session were never retried on a 5xx. The PATCH bodies
        # sent from this module set absolute state values, so retrying them is
        # safe. ``DEFAULT_ALLOWED_METHODS`` only exists on urllib3 >= 1.26; on
        # older versions we keep the library default.
        default_methods = getattr(Retry, "DEFAULT_ALLOWED_METHODS", None)
        if default_methods is not None:
            retry_options["allowed_methods"] = frozenset(default_methods) | {"PATCH"}

        adapter = HTTPAdapter(max_retries=Retry(**retry_options))
        session = requests.Session()
        # Mount for both schemes. Otherwise a plain-HTTP endpoint would use the
        # session's default adapter, which does not apply this retry policy.
        session.mount("https://", adapter)
        session.mount("http://", adapter)
        return session
|
|
|
|
|
def updateSource(sourceId, userId, state):
    """Set the ``state`` of an autodraft source through the backend API.

    Sends ``PATCH <api_endpoint>/autodraft-crecoai/source/{sourceId}`` with the
    JSON body ``{"state": state}``. This is a best-effort update: request
    failures (timeouts, connection errors, non-2xx replies) are printed and
    swallowed rather than raised.

    Args:
        sourceId: Identifier of the source; interpolated into the URL.
        userId: Sent as the ``user-id`` header (converted with ``str``).
        state: New state value; this module uses ``"INPROGRESS"``,
            ``"COMPLETED"`` and ``"FAILED"``.
    """
    url = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}"
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }
    data = {"state": state}

    try:
        with RetryRequest() as session:
            response = session.patch(url, headers=headers, json=data, timeout=10)
            # Without this, a 4xx/5xx reply would be silently treated as success.
            response.raise_for_status()
    except requests.exceptions.Timeout:
        print("Request timed out while updating source")
    except requests.exceptions.RequestException as e:
        # Also covers the HTTPError raised by raise_for_status().
        print(f"Error while updating source: {e}")
|
|
|
|
|
def saveGeneratedImages(sourceId, userId, has_nsfw: bool):
    """Mark a source's generated images as active and record the NSFW flag.

    Sends ``PATCH <api_endpoint>/autodraft-crecoai/source/{sourceId}/generatedImages``
    with the JSON body ``{"state": "ACTIVE", "has_nsfw": has_nsfw}``. This is a
    best-effort update: request failures are printed and swallowed.

    Args:
        sourceId: Identifier of the source; interpolated into the URL.
        userId: Sent as the ``user-id`` header (converted with ``str``).
        has_nsfw: Whether any generated image was flagged as NSFW.
    """
    url = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}/generatedImages"
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }
    data = {"state": "ACTIVE", "has_nsfw": has_nsfw}

    try:
        with RetryRequest() as session:
            # Without an explicit timeout, requests can block indefinitely and
            # the Timeout handler below could never fire.
            response = session.patch(url, headers=headers, json=data, timeout=10)
            response.raise_for_status()
    except requests.exceptions.Timeout:
        print("Request timed out while saving image")
    except requests.exceptions.RequestException as e:
        print("Failed to mark source as active: ", e)
|
|
|
|
|
def getStyles() -> Optional[Dict]:
    """Fetch the style catalogue from the autodraft API.

    Returns:
        The decoded JSON body of ``GET <api_endpoint>/autodraft-crecoai/style``,
        or ``None`` if the request timed out.

    Raises:
        requests.exceptions.RequestException: for any other request failure,
            including a non-2xx HTTP status. The error is printed before it
            is re-raised.
    """
    url = api_endpoint() + "/autodraft-crecoai/style"
    # SECURITY(review): this API key is hard-coded in source. It should be moved
    # into configuration (e.g. next to api_headers()) and rotated. Keys returned
    # by api_headers() override this entry if they collide.
    headers = {"x-api-key": "kGyEMp)oHB(zf^E5>-{o]I%go", **api_headers()}

    try:
        with RetryRequest() as session:
            response = session.get(url, timeout=10, headers=headers)
            # Surface HTTP errors instead of returning an error payload as styles.
            response.raise_for_status()
            return response.json()
    except requests.exceptions.Timeout:
        print("Request timed out while fetching styles")
    except requests.exceptions.RequestException as e:
        print(f"Error while fetching styles: {e}")
        raise
    return None
|
|
|
|
|
def getCharacters(model_id: str) -> Optional[List]:
    """Look up the characters attached to a model.

    Issues ``GET <api_endpoint>/autodraft-crecoai/model/{model_id}`` and
    extracts ``["data"]["characters"]`` from the JSON reply.

    This is intentionally best-effort: any failure (timeout, request error,
    malformed JSON, missing keys) is printed and ``None`` is returned.

    Args:
        model_id: Identifier of the model whose characters are requested.

    Returns:
        The ``characters`` value from the response, or ``None`` on failure.
    """
    endpoint = api_endpoint() + f"/autodraft-crecoai/model/{model_id}"
    try:
        with RetryRequest() as http:
            payload = http.get(endpoint, timeout=10, headers=api_headers()).json()
            return payload["data"]["characters"]
    except requests.exceptions.Timeout:
        print("Request timed out while fetching characters")
    except Exception as err:
        print(f"Error while fetching characters: {err}")
    return None
|
|
|
|
|
def update_db_source_failed(sourceId, userId):
    """Shortcut for ``updateSource(sourceId, userId, "FAILED")``."""
    updateSource(sourceId, userId, state="FAILED")
|
|
|
|
|
def update_db(func):
    """Decorate a task handler so the source's status is tracked in the backend.

    The wrapped function must receive a ``Task`` as its first positional
    argument. Status updates around the call:

    * before calling: the source is set to ``INPROGRESS``;
    * on success: it is set to ``COMPLETED``, then saveGeneratedImages is
      called with the ``has_nsfw`` flag read from the handler's returned
      mapping (``False`` if absent or if the handler returned ``None``);
    * on any exception: the error is printed, reported to Slack, and the
      source is set to ``FAILED``. The wrapper then returns ``None`` instead
      of re-raising.

    Raises:
        Exception: if no positional argument is given or the first one is not
            a ``Task``.
    """
    from functools import wraps

    @wraps(func)
    def caller(*args, **kwargs):
        if not args or not isinstance(args[0], Task):
            raise Exception("First argument must be a Task object")
        task = args[0]
        try:
            updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
            rargs = func(*args, **kwargs)
            # A handler that returns nothing has still succeeded. Don't let
            # ``None.get`` turn that into a FAILED source and a Slack alert.
            has_nsfw = rargs.get("has_nsfw", False) if rargs is not None else False
            updateSource(task.get_sourceId(), task.get_userId(), "COMPLETED")
            saveGeneratedImages(task.get_sourceId(), task.get_userId(), has_nsfw)
            return rargs
        except Exception as e:
            print("Error processing image: {}".format(str(e)))
            traceback.print_exc()
            try:
                slack = Slack()
                slack.error_alert(task, e)
            except Exception as alert_error:
                # Alerting is secondary; it must never prevent the FAILED update.
                print("Failed to send Slack error alert: {}".format(alert_error))
            updateSource(task.get_sourceId(), task.get_userId(), "FAILED")

    return caller
|
|