# CM2000112 / internals/data/dataAccessor.py
# Uploaded by jayparmr with huggingface_hub ("Upload folder using huggingface_hub"), commit a3f5c82, 4.01 kB.
import functools
import traceback
from typing import Dict, List, Optional

import requests
from pydash import includes
from requests.adapters import Retry, HTTPAdapter

from internals.data.task import Task
from internals.util.config import api_endpoint, api_headers
from internals.util.slack import Slack
class RetryRequest:
    """Build a ``requests.Session`` that retries failed requests.

    Calling ``RetryRequest()`` does NOT produce a ``RetryRequest`` instance.
    ``__new__`` returns a fresh ``requests.Session``, which is usable as a
    context manager. A urllib3 ``Retry`` policy is mounted on it for both
    ``https://`` and ``http://`` URLs.
    """

    def __new__(
        cls,
        total: int = 5,
        backoff_factor: float = 2,
        status_forcelist: Optional[List[int]] = None,
    ) -> requests.Session:
        """Return a new session with the retry policy applied.

        Args:
            total: maximum number of retries (urllib3 ``Retry.total``).
            backoff_factor: exponential backoff factor applied between retries.
            status_forcelist: HTTP status codes that trigger a retry; defaults
                to 500, 502, 503 and 504.

        Returns:
            A new ``requests.Session``. The caller must close it, e.g. with a
            ``with`` statement.
        """
        if status_forcelist is None:
            # Built per call so callers never share (or mutate) one list.
            status_forcelist = [500, 502, 503, 504]
        # NOTE(review): urllib3's default ``allowed_methods`` excludes PATCH and
        # POST, so status-based retries do not apply to those verbs (PATCH is
        # used by updateSource/saveGeneratedImages). Confirm that is intended.
        retry_policy = Retry(
            total=total,
            backoff_factor=backoff_factor,
            status_forcelist=status_forcelist,
        )
        session = requests.Session()
        # Mount on both schemes so an http:// api_endpoint() also gets retries.
        for prefix in ("https://", "http://"):
            session.mount(prefix, HTTPAdapter(max_retries=retry_policy))
        return session
def updateSource(sourceId, userId, state):
    """Set the ``state`` of a source record through the autodraft API.

    Sends ``PATCH {api_endpoint()}/autodraft-crecoai/source/{sourceId}`` with
    JSON body ``{"state": state}`` and a 10 second timeout. Request errors,
    timeouts and non-2xx responses are printed and not re-raised.

    Args:
        sourceId: identifier of the source record; interpolated into the URL.
        userId: sent (stringified) as the ``user-id`` header.
        state: new state value. This module passes "INPROGRESS", "COMPLETED"
            and "FAILED".

    Returns:
        None.
    """
    url = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}"
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }
    data = {"state": state}
    try:
        with RetryRequest() as session:
            response = session.patch(url, headers=headers, json=data, timeout=10)
            # requests returns 4xx/5xx responses without raising.
            # raise_for_status() turns them into HTTPError (a RequestException),
            # so the handler below reports them instead of dropping them.
            response.raise_for_status()
    except requests.exceptions.Timeout:
        print("Request timed out while updating source")
    except requests.exceptions.RequestException as e:
        print(f"Error while updating source: {e}")
def saveGeneratedImages(sourceId, userId, has_nsfw: bool):
    """Mark a source's generated images as active through the autodraft API.

    Sends ``PATCH {api_endpoint()}/autodraft-crecoai/source/{sourceId}/generatedImages``
    with JSON body ``{"state": "ACTIVE", "has_nsfw": has_nsfw}``. Request
    errors, timeouts and non-2xx responses are printed and not re-raised.

    Args:
        sourceId: identifier of the source record; stringified into the URL.
        userId: sent (stringified) as the ``user-id`` header.
        has_nsfw: flag forwarded in the request body.

    Returns:
        None.
    """
    url = (
        api_endpoint()
        + "/autodraft-crecoai/source/"
        + str(sourceId)
        + "/generatedImages"
    )
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }
    data = {"state": "ACTIVE", "has_nsfw": has_nsfw}
    try:
        with RetryRequest() as session:
            # requests applies no timeout unless one is given, so without it
            # this call could block indefinitely; 10s matches updateSource.
            response = session.patch(url, headers=headers, json=data, timeout=10)
            # Surface 4xx/5xx as HTTPError so the handler below reports them.
            response.raise_for_status()
    except requests.exceptions.Timeout:
        print("Request timed out while saving image")
    except requests.exceptions.RequestException as e:
        print("Failed to mark source as active: ", e)
def getStyles() -> Optional[Dict]:
    """Fetch the available styles from the autodraft API.

    Issues ``GET {api_endpoint()}/autodraft-crecoai/style`` with a 10 second
    timeout and returns the decoded JSON body. The HTTP status is not checked.

    Returns:
        The parsed JSON body, or None if the request timed out.

    Raises:
        requests.exceptions.RequestException: any other request error is
            printed and then re-raised.
    """
    styles_url = api_endpoint() + "/autodraft-crecoai/style"
    try:
        with RetryRequest() as session:
            # NOTE(review): a literal API key is committed in source here.
            # Any "x-api-key" returned by api_headers() overrides it, because
            # the later update wins. Consider moving it to config/secrets and
            # rotating it.
            request_headers = {"x-api-key": "kGyEMp)oHB(zf^E5>-{o]I%go"}
            request_headers.update(api_headers())
            reply = session.get(styles_url, timeout=10, headers=request_headers)
            return reply.json()
    except requests.exceptions.Timeout:
        print("Request timed out while fetching styles")
        return None
    except requests.exceptions.RequestException as exc:
        print(f"Error while fetching styles: {exc}")
        raise
def getCharacters(model_id: str) -> Optional[List]:
    """Fetch the characters configured for a model.

    Issues ``GET {api_endpoint()}/autodraft-crecoai/model/{model_id}`` with a
    10 second timeout and reads ``["data"]["characters"]`` from the JSON body.

    Returns:
        The characters value, or None on timeout or on any other exception
        (request error, undecodable JSON, missing keys). The error is printed.
    """
    model_url = api_endpoint() + "/autodraft-crecoai/model/{}".format(model_id)
    try:
        with RetryRequest() as session:
            body = session.get(model_url, timeout=10, headers=api_headers()).json()
            return body["data"]["characters"]
    except requests.exceptions.Timeout:
        print("Request timed out while fetching characters")
    except Exception as err:
        # Deliberately broad: callers get None rather than an exception.
        print(f"Error while fetching characters: {err}")
    return None
def update_db_source_failed(sourceId, userId):
    """Set the source's state to "FAILED" by delegating to ``updateSource``."""
    failed_state = "FAILED"
    updateSource(sourceId, userId, failed_state)
def update_db(func):
    """Decorate a task-processing function so it records its progress in the backend.

    The wrapped function must receive a ``Task`` as its first positional
    argument.

    - Before the call, the source is marked "INPROGRESS".
    - On success, the source is marked "COMPLETED" and ``saveGeneratedImages``
      is called with the ``has_nsfw`` value from the returned mapping (default
      False). The function's return value is then passed through.
    - If any step raises, the error is printed, a Slack alert is attempted and
      the source is marked "FAILED". The wrapper then returns None.

    Raises:
        Exception: if the first positional argument is missing or not a Task.
    """

    @functools.wraps(func)
    def caller(*args, **kwargs):
        # Guard on ``args`` too, so a call without positional arguments gets
        # this message instead of an IndexError.
        if not args or not isinstance(args[0], Task):
            raise Exception("First argument must be a Task object")
        task = args[0]
        try:
            updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
            rargs = func(*args, **kwargs)
            # Assumes func returns a dict-like object. A None return raises
            # AttributeError here and is handled as a failure below.
            has_nsfw = rargs.get("has_nsfw", False)
            updateSource(task.get_sourceId(), task.get_userId(), "COMPLETED")
            saveGeneratedImages(task.get_sourceId(), task.get_userId(), has_nsfw)
            return rargs
        except Exception as e:
            print("Error processing image: {}".format(str(e)))
            traceback.print_exc()
            try:
                slack = Slack()
                slack.error_alert(task, e)
            except Exception as alert_error:
                # A broken alert channel must not prevent the FAILED update.
                print("Failed to send Slack error alert: {}".format(alert_error))
            updateSource(task.get_sourceId(), task.get_userId(), "FAILED")

    return caller