CM2000112 / internals /data /dataAccessor.py
jayparmr's picture
Upload folder using huggingface_hub
7fbdac4
raw
history blame
4.1 kB
import traceback
from typing import Dict, List, Optional
import requests
from pydash import includes
from requests.adapters import HTTPAdapter, Retry
from internals.data.task import Task
from internals.util.config import api_endpoint, api_headers
from internals.util.slack import Slack
class RetryRequest:
    """Factory for a ``requests.Session`` that retries on selected 5xx codes.

    Calling ``RetryRequest()`` does not yield a ``RetryRequest`` instance:
    ``__new__`` builds and returns a ``requests.Session`` directly, so the
    result can be used as ``with RetryRequest() as session: ...``.
    """

    def __new__(cls):
        # Up to 5 retries, backoff factor 2, on 500/502/503/504 responses.
        # NOTE(review): urllib3's default ``allowed_methods`` does not include
        # PATCH/POST, so status-based retries may not apply to the PATCH calls
        # made elsewhere in this module — confirm this is intended.
        retry_policy = Retry(
            total=5,
            backoff_factor=2,
            status_forcelist=[500, 502, 503, 504],
        )
        session = requests.Session()
        # Only https:// URLs get the retrying adapter; http:// uses defaults.
        session.mount("https://", HTTPAdapter(max_retries=retry_policy))
        return session
def updateSource(sourceId, userId, state):
    """Set the ``state`` of a source record via the autodraft API.

    Sends ``PATCH {api_endpoint()}/autodraft-crecoai/source/{sourceId}`` with
    JSON body ``{"state": state}``. This is best-effort: request failures
    (timeouts, connection errors, non-2xx responses) are printed and
    swallowed, never raised.

    Args:
        sourceId: Identifier of the source; interpolated into the URL.
        userId: Sent as the ``user-id`` header (stringified).
        state: New state value, e.g. ``"INPROGRESS"``, ``"COMPLETED"`` or
            ``"FAILED"`` as used by ``update_db``.
    """
    url = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}"
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }
    data = {"state": state}
    try:
        with RetryRequest() as session:
            response = session.patch(url, headers=headers, json=data, timeout=10)
            # Previously a 4xx/5xx reply was discarded silently. HTTPError is a
            # RequestException subclass, so it is logged by the handler below.
            response.raise_for_status()
    except requests.exceptions.Timeout:
        print("Request timed out while updating source")
    except requests.exceptions.RequestException as e:
        print(f"Error while updating source: {e}")
def saveGeneratedImages(sourceId, userId, has_nsfw: bool):
    """Mark a source ``"ACTIVE"`` and record its ``has_nsfw`` flag.

    Sends ``PATCH {api_endpoint()}/autodraft-crecoai/source/{sourceId}/generatedImages``
    with JSON body ``{"state": "ACTIVE", "has_nsfw": has_nsfw}``. This is
    best-effort: request failures (timeouts, connection errors, non-2xx
    responses) are printed and swallowed, never raised.

    Args:
        sourceId: Identifier of the source; stringified into the URL.
        userId: Sent as the ``user-id`` header (stringified).
        has_nsfw: Flag forwarded in the request body.
    """
    url = (
        api_endpoint()
        + "/autodraft-crecoai/source/"
        + str(sourceId)
        + "/generatedImages"
    )
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }
    data = {"state": "ACTIVE", "has_nsfw": has_nsfw}
    try:
        with RetryRequest() as session:
            # Without a timeout this call could block indefinitely, leaving the
            # Timeout handler below unreachable; 10 s matches the sibling calls.
            response = session.patch(url, headers=headers, json=data, timeout=10)
            # Surface 4xx/5xx replies instead of ignoring them; HTTPError is a
            # RequestException and is logged below.
            response.raise_for_status()
    except requests.exceptions.Timeout:
        print("Request timed out while saving image")
    except requests.exceptions.RequestException as e:
        print("Failed to mark source as active: ", e)
def getStyles() -> Optional[Dict]:
    """Fetch the style list from the autodraft API.

    Returns:
        The decoded JSON body of ``GET {api_endpoint()}/autodraft-crecoai/style``,
        or ``None`` if the request timed out.

    Raises:
        requests.exceptions.RequestException: any other request error is
            printed and then re-raised.
    """
    style_url = api_endpoint() + "/autodraft-crecoai/style"
    try:
        with RetryRequest() as session:
            # NOTE(review): hard-coded credential committed to source control;
            # consider loading it from configuration/secrets. Entries returned
            # by api_headers() override this key on collision (dict-literal order).
            request_headers = {"x-api-key": "kGyEMp)oHB(zf^E5>-{o]I%go", **api_headers()}
            reply = session.get(style_url, timeout=10, headers=request_headers)
            return reply.json()
    except requests.exceptions.Timeout:
        print("Request timed out while fetching styles")
    except requests.exceptions.RequestException as err:
        print(f"Error while fetching styles: {err}")
        raise err
    return None
def getCharacters(model_id: str) -> Optional[List]:
    """Fetch the characters attached to a model.

    Reads ``["data"]["characters"]`` from the JSON body of
    ``GET {api_endpoint()}/autodraft-crecoai/model/{model_id}``.

    Returns:
        That value, or ``None`` if the request timed out or any other
        exception occurred (including missing keys in the payload); the
        error is printed in either case.
    """
    endpoint = api_endpoint() + f"/autodraft-crecoai/model/{model_id}"
    try:
        with RetryRequest() as session:
            payload = session.get(endpoint, timeout=10, headers=api_headers()).json()
            return payload["data"]["characters"]
    except requests.exceptions.Timeout:
        print("Request timed out while fetching characters")
    except Exception as exc:
        print(f"Error while fetching characters: {exc}")
    return None
def update_db_source_failed(sourceId, userId):
    """Set the source's state to ``"FAILED"`` by delegating to ``updateSource``."""
    updateSource(sourceId, userId, state="FAILED")
def update_db(func):
    """Decorate ``func`` so its run is mirrored into the source's state.

    The wrapped function must receive a ``Task`` among its arguments
    (positional arguments are searched first, then keyword values). Around
    the call the source is set to ``"INPROGRESS"``, then to ``"COMPLETED"``
    on success, after which ``saveGeneratedImages`` is called with the
    ``has_nsfw`` value read from the function's returned mapping (default
    ``False``), and that mapping is returned.

    On any exception the error and traceback are printed, a Slack alert is
    sent, and the source is set to ``"FAILED"``; the original exception is
    not re-raised and the wrapper returns ``None``. If the Slack alert itself
    raises, the source is still set to ``"FAILED"`` and the Slack error
    propagates.

    Raises:
        Exception: if no ``Task`` instance is found among the arguments.
    """
    import functools

    @functools.wraps(func)
    def caller(*args, **kwargs):
        task = next(
            (arg for arg in (*args, *kwargs.values()) if isinstance(arg, Task)),
            None,
        )
        if task is None:
            raise Exception("A Task object must be passed as an argument")
        try:
            updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
            rargs = func(*args, **kwargs)
            # Assumes func returns a dict-like object; a return value without
            # .get() raises here and is handled as a failure below.
            has_nsfw = rargs.get("has_nsfw", False)
            updateSource(task.get_sourceId(), task.get_userId(), "COMPLETED")
            saveGeneratedImages(task.get_sourceId(), task.get_userId(), has_nsfw)
            return rargs
        except Exception as e:
            print("Error processing image: {}".format(str(e)))
            traceback.print_exc()
            try:
                slack = Slack()
                slack.error_alert(task, e)
            finally:
                # Mark FAILED even if Slack alerting raises, so the source is
                # never left stuck in INPROGRESS.
                updateSource(task.get_sourceId(), task.get_userId(), "FAILED")

    return caller