nsfwalex committed on
Commit
69fb4e6
·
verified ·
1 Parent(s): 7df1bb1

get model configs from local file instead of urls

Browse files
Files changed (1) hide show
  1. inference_manager.py +22 -18
inference_manager.py CHANGED
@@ -259,36 +259,43 @@ class InferenceManager:
259
  torch.cuda.empty_cache()
260
  print("Memory released and cache cleared.")
261
 
 
262
  class ModelManager:
263
- def __init__(self, config_urls):
264
  """
265
- Initialize the ModelManager by loading all models specified by the URLs.
266
 
267
- :param config_urls: List of URLs pointing to model config files (e.g., ["model1/config.json", "model2/config.json"]).
268
  """
269
  self.models = {}
270
- self.load_models(config_urls)
 
271
 
272
- def load_models(self, config_urls):
273
  """
274
- Load and initialize InferenceManager instances for each config URL.
275
 
276
- :param config_urls: List of config file URLs.
277
  """
278
- for url in config_urls:
279
- model_name = self.get_model_name_from_url(url)
280
- print(f"Initializing model: {model_name} from {url}")
 
 
 
 
 
281
  try:
282
  # Initialize InferenceManager for each model
283
- self.models[model_name] = InferenceManager(config_path=url)
284
  except Exception as e:
285
- print(f"Failed to initialize model {model_name} from {url}: {e}")
286
 
287
  def get_model_name_from_url(self, url):
288
  """
289
- Extract the model name from the config URL (filename without extension).
290
 
291
- :param url: The URL of the configuration file.
292
  :return: The model name (file name without extension).
293
  """
294
  filename = os.path.basename(url)
@@ -331,10 +338,7 @@ class ModelManager:
331
  model.release(model.base_model_pipeline)
332
  except Exception as e:
333
  print(f"Failed to release model {model_id}: {e}")
334
-
335
-
336
-
337
-
338
  # Hugging Face file download function - returns only file path
339
  def download_from_hf(filename, local_dir=None):
340
  try:
 
259
  torch.cuda.empty_cache()
260
  print("Memory released and cache cleared.")
261
 
262
+
263
  class ModelManager:
264
+ def __init__(self, model_directory):
265
  """
266
+ Initialize the ModelManager by scanning all `.model.json` files in the given directory.
267
 
268
+ :param model_directory: The directory to scan for model config files (e.g., "/path/to/models").
269
  """
270
  self.models = {}
271
+ self.model_directory = model_directory
272
+ self.load_models()
273
 
274
+ def load_models(self):
275
  """
276
+ Scan the model directory for `.model.json` files and initialize InferenceManager instances for each one.
277
 
278
+ :param model_directory: Directory to scan for `.model.json` files.
279
  """
280
+ model_files = glob.glob(os.path.join(self.model_directory, "*.model.json"))
281
+ if not model_files:
282
+ print(f"No model configuration files found in {self.model_directory}")
283
+ return
284
+
285
+ for file_path in model_files:
286
+ model_name = self.get_model_name_from_url(file_path)
287
+ print(f"Initializing model: {model_name} from {file_path}")
288
  try:
289
  # Initialize InferenceManager for each model
290
+ self.models[model_name] = InferenceManager(config_path=file_path)
291
  except Exception as e:
292
+ print(f"Failed to initialize model {model_name} from {file_path}: {e}")
293
 
294
  def get_model_name_from_url(self, url):
295
  """
296
+ Extract the model name from the config file path (filename without extension).
297
 
298
+ :param url: The file path of the configuration file.
299
  :return: The model name (file name without extension).
300
  """
301
  filename = os.path.basename(url)
 
338
  model.release(model.base_model_pipeline)
339
  except Exception as e:
340
  print(f"Failed to release model {model_id}: {e}")
341
+
 
 
 
342
  # Hugging Face file download function - returns only file path
343
  def download_from_hf(filename, local_dir=None):
344
  try: