CM2000112 / internals /data /dataAccessor.py
jayparmr's picture
Upload folder using huggingface_hub
10230ea
raw
history blame
4.15 kB
import traceback
from typing import Dict, List, Optional
from requests.adapters import Retry, HTTPAdapter
import requests
from pydash import includes
from internals.data.task import Task
from internals.util.config import api_endpoint, api_headers
from internals.util.slack import Slack
class RetryRequest:
    """Factory for a ``requests.Session`` that retries failed HTTPS calls.

    Calling ``RetryRequest()`` does NOT produce a ``RetryRequest`` instance:
    ``__new__`` returns a brand-new ``requests.Session`` instead, so the call
    can be used directly as a context manager::

        with RetryRequest() as session:
            session.get(...)
    """

    def __new__(cls):
        # Up to 5 retries, backoff factor 2, retrying on these server /
        # gateway error statuses.
        # NOTE(review): urllib3's default ``allowed_methods`` may not include
        # PATCH, so status-based retries might not apply to the PATCH calls
        # made with this session — confirm against the installed urllib3.
        retry_policy = Retry(
            total=5,
            backoff_factor=2,
            status_forcelist=[500, 502, 503, 504],
        )
        http_session = requests.Session()
        # Only https:// URLs get the retrying adapter; http:// keeps the
        # session's default adapter.
        http_session.mount("https://", HTTPAdapter(max_retries=retry_policy))
        return http_session
def updateSource(sourceId, userId, state):
    """PATCH the ``state`` of a source record on the backend API.

    Best effort: request errors (including timeouts) are printed and never
    raised. The function always returns ``None``.

    :param sourceId: id of the source; interpolated into the URL path.
    :param userId: sent (stringified) as the ``user-id`` header.
    :param state: new state string (this module passes e.g. ``"INPROGRESS"``,
        ``"COMPLETED"`` and ``"FAILED"``).
    """
    print("update source is called")
    endpoint = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}"

    request_headers = {"Content-Type": "application/json", "user-id": str(userId)}
    # Entries from api_headers() take precedence over the defaults above.
    request_headers.update(api_headers())

    payload = {"state": state}
    try:
        with RetryRequest() as http:
            reply = http.patch(endpoint, headers=request_headers, json=payload, timeout=10)
            print("update source response", reply)
    except requests.exceptions.Timeout:
        print("Request timed out while updating source")
    except requests.exceptions.RequestException as err:
        print(f"Error while updating source: {err}")
    return None
def saveGeneratedImages(sourceId, userId, has_nsfw: bool):
    """Mark a source's generated images as ``ACTIVE`` on the backend API.

    Sends ``{"state": "ACTIVE", "has_nsfw": has_nsfw}`` via PATCH to
    ``<api_endpoint>/autodraft-crecoai/source/<sourceId>/generatedImages``.
    Best effort: request errors (including timeouts) are printed and
    swallowed. The function always returns ``None``.

    :param sourceId: id of the source; interpolated into the URL path.
    :param userId: sent (stringified) as the ``user-id`` header.
    :param has_nsfw: whether any generated image was flagged NSFW.
    """
    print("save generation called")
    url = (
        api_endpoint()
        + "/autodraft-crecoai/source/"
        + str(sourceId)
        + "/generatedImages"
    )
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }
    data = {"state": "ACTIVE", "has_nsfw": has_nsfw}
    try:
        with RetryRequest() as session:
            # Without an explicit timeout, requests can block indefinitely
            # and the Timeout handler below is effectively unreachable.
            # 10 s matches the other calls in this module.
            session.patch(url, headers=headers, json=data, timeout=10)
    except requests.exceptions.Timeout:
        print("Request timed out while saving image")
    except requests.exceptions.RequestException as e:
        print("Failed to mark source as active: ", e)
def getStyles() -> Optional[Dict]:
    """Fetch the style catalogue from the backend API.

    Issues ``GET <api_endpoint>/autodraft-crecoai/style`` and returns the
    decoded JSON body. The HTTP status code is not checked.

    :returns: the parsed JSON, or ``None`` if the request timed out.
    :raises requests.exceptions.RequestException: any other request error is
        printed and then re-raised.
    """
    styles_url = api_endpoint() + "/autodraft-crecoai/style"
    print(styles_url)
    try:
        with RetryRequest() as http:
            # NOTE(review): this API key is hard-coded in source control.
            # Consider loading it from configuration, as api_headers() does,
            # and rotating it. Keys from api_headers() override this entry.
            request_headers = {"x-api-key": "kGyEMp)oHB(zf^E5>-{o]I%go", **api_headers()}
            reply = http.get(styles_url, timeout=10, headers=request_headers)
            return reply.json()
    except requests.exceptions.Timeout:
        print("Request timed out while fetching styles")
        return None
    except requests.exceptions.RequestException as err:
        print(f"Error while fetching styles: {err}")
        raise
def getCharacters(model_id: str) -> Optional[List]:
    """Return the ``characters`` list for a model from the backend API.

    Reads ``["data"]["characters"]`` from the JSON body of
    ``GET <api_endpoint>/autodraft-crecoai/model/<model_id>``.

    Best effort: any failure is printed and yields ``None``. This covers
    timeouts, connection errors, undecodable JSON and missing keys.
    """
    endpoint = api_endpoint() + f"/autodraft-crecoai/model/{model_id}"
    try:
        with RetryRequest() as http:
            body = http.get(endpoint, timeout=10, headers=api_headers()).json()
            return body["data"]["characters"]
    except requests.exceptions.Timeout:
        print("Request timed out while fetching characters")
    except Exception as err:
        print(f"Error while fetching characters: {err}")
    return None
def update_db_source_failed(sourceId, userId):
    """Set the source's backend state to ``FAILED`` (best effort, returns None)."""
    updateSource(sourceId, userId, state="FAILED")
def update_db(func):
    """Decorate a task handler so the backend source state tracks its lifecycle.

    The wrapped function must receive a ``Task`` as its first positional
    argument. Around the call:

    * the source is marked ``INPROGRESS`` before ``func`` runs;
    * on success it is marked ``COMPLETED``, and the generated images are
      saved with the ``has_nsfw`` flag read from ``func``'s return value;
      the return value is then passed through;
    * on any exception the error is printed, reported to Slack, and the
      source is marked ``FAILED``. The exception is swallowed and the
      wrapper returns ``None``.

    :raises Exception: if no positional argument is given or the first one is
        not a ``Task``. This check happens before any backend call.
    """
    import functools

    @functools.wraps(func)
    def caller(*args, **kwargs):
        # Also guard the empty case, so callers get this message rather than
        # an IndexError.
        if not args or not isinstance(args[0], Task):
            raise Exception("First argument must be a Task object")
        task = args[0]
        try:
            updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
            rargs = func(*args, **kwargs)
            # NOTE(review): assumes func returns a dict-like result. A None or
            # other object without .get() raises here, and the task is
            # reported as FAILED.
            has_nsfw = rargs.get("has_nsfw", False)
            updateSource(task.get_sourceId(), task.get_userId(), "COMPLETED")
            saveGeneratedImages(task.get_sourceId(), task.get_userId(), has_nsfw)
            return rargs
        except Exception as e:
            print("Error processing image: {}".format(str(e)))
            traceback.print_exc()
            try:
                slack = Slack()
                slack.error_alert(task, e)
            finally:
                # Record the failure even if the Slack alert itself raises.
                updateSource(task.get_sourceId(), task.get_userId(), "FAILED")

    return caller