jayparmr commited on
Commit
a3f5c82
1 Parent(s): 0ec6d80

Upload folder using huggingface_hub

Browse files
inference2.py CHANGED
@@ -4,7 +4,7 @@ from io import BytesIO
4
  import torch
5
 
6
  import internals.util.prompt as prompt_util
7
- from internals.data.dataAccessor import update_db
8
  from internals.data.task import ModelType, Task, TaskType
9
  from internals.pipelines.controlnets import ControlNet
10
  from internals.pipelines.high_res import HighRes
@@ -194,7 +194,10 @@ def replace_bg(task: Task):
194
  def upscale_image(task: Task):
195
  output_key = "crecoAI/{}_upscale.png".format(task.get_taskId())
196
  out_img = None
197
- if task.get_modelType() == ModelType.ANIME:
 
 
 
198
  print("Using Anime model")
199
  out_img = upscaler.upscale_anime(
200
  image=task.get_imageUrl(),
@@ -297,4 +300,5 @@ def predict_fn(data, pipe):
297
  print(f"Error: {e}")
298
  slack.error_alert(task, e)
299
  controlnet.cleanup()
 
300
  return None
 
4
  import torch
5
 
6
  import internals.util.prompt as prompt_util
7
+ from internals.data.dataAccessor import update_db, update_db_source_failed
8
  from internals.data.task import ModelType, Task, TaskType
9
  from internals.pipelines.controlnets import ControlNet
10
  from internals.pipelines.high_res import HighRes
 
194
  def upscale_image(task: Task):
195
  output_key = "crecoAI/{}_upscale.png".format(task.get_taskId())
196
  out_img = None
197
+ if (
198
+ task.get_modelType() == ModelType.ANIME
199
+ or task.get_modelType() == ModelType.COMIC
200
+ ):
201
  print("Using Anime model")
202
  out_img = upscaler.upscale_anime(
203
  image=task.get_imageUrl(),
 
300
  print(f"Error: {e}")
301
  slack.error_alert(task, e)
302
  controlnet.cleanup()
303
+ update_db_source_failed(task.get_sourceId(), task.get_userId())
304
  return None
internals/data/dataAccessor.py CHANGED
@@ -19,7 +19,6 @@ class RetryRequest:
19
 
20
 
21
  def updateSource(sourceId, userId, state):
22
- print("update source is called")
23
  url = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}"
24
  headers = {
25
  "Content-Type": "application/json",
@@ -32,7 +31,6 @@ def updateSource(sourceId, userId, state):
32
  try:
33
  with RetryRequest() as session:
34
  response = session.patch(url, headers=headers, json=data, timeout=10)
35
- print("update source response", response)
36
  except requests.exceptions.Timeout:
37
  print("Request timed out while updating source")
38
  except requests.exceptions.RequestException as e:
@@ -42,7 +40,6 @@ def updateSource(sourceId, userId, state):
42
 
43
 
44
  def saveGeneratedImages(sourceId, userId, has_nsfw: bool):
45
- print("save generation called")
46
  url = (
47
  api_endpoint()
48
  + "/autodraft-crecoai/source/"
@@ -70,7 +67,6 @@ def saveGeneratedImages(sourceId, userId, has_nsfw: bool):
70
 
71
  def getStyles() -> Optional[Dict]:
72
  url = api_endpoint() + "/autodraft-crecoai/style"
73
- print(url)
74
  try:
75
  with RetryRequest() as session:
76
  response = session.get(
 
19
 
20
 
21
  def updateSource(sourceId, userId, state):
 
22
  url = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}"
23
  headers = {
24
  "Content-Type": "application/json",
 
31
  try:
32
  with RetryRequest() as session:
33
  response = session.patch(url, headers=headers, json=data, timeout=10)
 
34
  except requests.exceptions.Timeout:
35
  print("Request timed out while updating source")
36
  except requests.exceptions.RequestException as e:
 
40
 
41
 
42
  def saveGeneratedImages(sourceId, userId, has_nsfw: bool):
 
43
  url = (
44
  api_endpoint()
45
  + "/autodraft-crecoai/source/"
 
67
 
68
  def getStyles() -> Optional[Dict]:
69
  url = api_endpoint() + "/autodraft-crecoai/style"
 
70
  try:
71
  with RetryRequest() as session:
72
  response = session.get(
internals/data/task.py CHANGED
@@ -26,6 +26,10 @@ class ModelType(Enum):
26
  ANIME = 10001
27
  COMIC = 10002
28
 
 
 
 
 
29
 
30
  class Task:
31
  def __init__(self, data):
@@ -156,6 +160,9 @@ class Task:
156
  def get_high_res_fix(self) -> bool:
157
  return self.__data.get("high_res_fix", False)
158
 
 
 
 
159
  def get_raw(self) -> dict:
160
  return self.__data.copy()
161
 
 
26
  ANIME = 10001
27
  COMIC = 10002
28
 
29
+ @classmethod
30
+ def _missing_(cls, value):
31
+ return cls.REAL
32
+
33
 
34
  class Task:
35
  def __init__(self, data):
 
160
  def get_high_res_fix(self) -> bool:
161
  return self.__data.get("high_res_fix", False)
162
 
163
+ def get_base_dimension(self) -> int:
164
+ return self.__data.get("base_dimension", 512)
165
+
166
  def get_raw(self) -> dict:
167
  return self.__data.copy()
168
 
internals/pipelines/high_res.py CHANGED
@@ -5,7 +5,7 @@ from PIL import Image
5
 
6
  from internals.data.result import Result
7
  from internals.pipelines.commons import AbstractPipeline, Img2Img
8
- from internals.util.config import get_model_dir
9
 
10
 
11
  class HighRes(AbstractPipeline):
@@ -42,7 +42,7 @@ class HighRes(AbstractPipeline):
42
 
43
  @staticmethod
44
  def get_intermediate_dimension(target_width: int, target_height: int):
45
- def_size = 1024
46
 
47
  desired_pixel_count = def_size * def_size
48
  actual_pixel_count = target_width * target_height
 
5
 
6
  from internals.data.result import Result
7
  from internals.pipelines.commons import AbstractPipeline, Img2Img
8
+ from internals.util.config import get_model_dir, get_base_dimension
9
 
10
 
11
  class HighRes(AbstractPipeline):
 
42
 
43
  @staticmethod
44
  def get_intermediate_dimension(target_width: int, target_height: int):
45
+ def_size = get_base_dimension()
46
 
47
  desired_pixel_count = def_size * def_size
48
  actual_pixel_count = target_width * target_height
internals/pipelines/upscaler.py CHANGED
@@ -87,7 +87,7 @@ class Upscaler:
87
  num_in_ch=3,
88
  num_out_ch=3,
89
  num_feat=64,
90
- num_block=23,
91
  num_grow_ch=32,
92
  scale=4,
93
  )
 
87
  num_in_ch=3,
88
  num_out_ch=3,
89
  num_feat=64,
90
+ num_block=6,
91
  num_grow_ch=32,
92
  scale=4,
93
  )
internals/util/config.py CHANGED
@@ -13,6 +13,7 @@ root_dir = ""
13
  model_config = None
14
  hf_token = "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"
15
  hf_cache_dir = "/tmp/hf_hub"
 
16
 
17
  num_return_sequences = 4 # the number of results to generate
18
 
@@ -40,7 +41,7 @@ def set_model_config(config: ModelConfig):
40
 
41
 
42
  def set_configs_from_task(task: Task):
43
- global env, nsfw_threshold, nsfw_access, access_token
44
  name = task.get_queue_name()
45
  if name.startswith("gamma"):
46
  env = "gamma"
@@ -49,6 +50,7 @@ def set_configs_from_task(task: Task):
49
  nsfw_threshold = task.get_nsfw_threshold()
50
  nsfw_access = task.can_access_nsfw()
51
  access_token = task.get_access_token()
 
52
 
53
 
54
  def get_model_dir():
@@ -61,6 +63,11 @@ def get_inpaint_model_path():
61
  return model_config.base_inpaint_model_path # pyright: ignore
62
 
63
 
 
 
 
 
 
64
  def get_is_sdxl():
65
  global model_config
66
  return model_config.is_sdxl # pyright: ignore
 
13
  model_config = None
14
  hf_token = "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"
15
  hf_cache_dir = "/tmp/hf_hub"
16
+ base_dimension = 512 # needed for high res
17
 
18
  num_return_sequences = 4 # the number of results to generate
19
 
 
41
 
42
 
43
  def set_configs_from_task(task: Task):
44
+ global env, nsfw_threshold, nsfw_access, access_token, base_dimension
45
  name = task.get_queue_name()
46
  if name.startswith("gamma"):
47
  env = "gamma"
 
50
  nsfw_threshold = task.get_nsfw_threshold()
51
  nsfw_access = task.can_access_nsfw()
52
  access_token = task.get_access_token()
53
+ base_dimension = task.get_base_dimension()
54
 
55
 
56
  def get_model_dir():
 
63
  return model_config.base_inpaint_model_path # pyright: ignore
64
 
65
 
66
+ def get_base_dimension():
67
+ global base_dimension
68
+ return base_dimension
69
+
70
+
71
  def get_is_sdxl():
72
  global model_config
73
  return model_config.is_sdxl # pyright: ignore