jayparmr commited on
Commit
9387217
·
1 Parent(s): e6a4021

Upload folder using huggingface_hub

Browse files
internals/data/task.py CHANGED
@@ -145,8 +145,8 @@ class Task:
145
  def get_high_res_fix(self) -> bool:
146
  return self.__data.get("high_res_fix", False)
147
 
148
- def get_base_dimension(self) -> int:
149
- return self.__data.get("base_dimension", 512)
150
 
151
  def get_raw(self) -> dict:
152
  return self.__data.copy()
 
145
  def get_high_res_fix(self) -> bool:
146
  return self.__data.get("high_res_fix", False)
147
 
148
+ def get_base_dimension(self):
149
+ return self.__data.get("base_dimension", None)
150
 
151
  def get_raw(self) -> dict:
152
  return self.__data.copy()
internals/util/config.py CHANGED
@@ -16,6 +16,7 @@ hf_token = base64.b64decode(
16
  b"aGZfVFZCTHNUam1tT3d6T0h1dlVZWkhEbEZ4WVdOSUdGamVCbA=="
17
  ).decode()
18
  hf_cache_dir = "/tmp/hf_hub"
 
19
  base_dimension = 512 # needed for high res
20
 
21
  num_return_sequences = 4 # the number of results to generate
@@ -67,8 +68,10 @@ def get_inpaint_model_path():
67
 
68
 
69
  def get_base_dimension():
70
- global base_dimension
71
- return base_dimension
 
 
72
 
73
 
74
  def get_is_sdxl():
 
16
  b"aGZfVFZCTHNUam1tT3d6T0h1dlVZWkhEbEZ4WVdOSUdGamVCbA=="
17
  ).decode()
18
  hf_cache_dir = "/tmp/hf_hub"
19
+
20
  base_dimension = 512 # needed for high res
21
 
22
  num_return_sequences = 4 # the number of results to generate
 
68
 
69
 
70
  def get_base_dimension():
71
+ global global_base_dimension, base_dimension
72
+ if base_dimension:
73
+ return base_dimension
74
+ return model_config.base_dimension # pyright: ignore
75
 
76
 
77
  def get_is_sdxl():
internals/util/model_loader.py CHANGED
@@ -15,6 +15,7 @@ class ModelConfig:
15
  base_model_path: str
16
  base_inpaint_model_path: str
17
  is_sdxl: bool = False
 
18
 
19
 
20
  def load_model_from_config(path):
@@ -25,10 +26,12 @@ def load_model_from_config(path):
25
  model_path = config.get("model_path", path)
26
  inpaint_model_path = config.get("inpaint_model_path", path)
27
  is_sdxl = config.get("is_sdxl", False)
 
28
 
29
  m_config.base_model_path = model_path
30
  m_config.base_inpaint_model_path = inpaint_model_path
31
  m_config.is_sdxl = is_sdxl
 
32
 
33
  #
34
  # if config.get("model_type") == "huggingface":
 
15
  base_model_path: str
16
  base_inpaint_model_path: str
17
  is_sdxl: bool = False
18
+ base_dimension: int = 512
19
 
20
 
21
  def load_model_from_config(path):
 
26
  model_path = config.get("model_path", path)
27
  inpaint_model_path = config.get("inpaint_model_path", path)
28
  is_sdxl = config.get("is_sdxl", False)
29
+ base_dimension = config.get("base_dimension", 512)
30
 
31
  m_config.base_model_path = model_path
32
  m_config.base_inpaint_model_path = inpaint_model_path
33
  m_config.is_sdxl = is_sdxl
34
+ m_config.base_dimension = base_dimension
35
 
36
  #
37
  # if config.get("model_type") == "huggingface":