get model configs from local file instead of urls

inference_manager.py CHANGED (+22 -18)
@@ -259,36 +259,43 @@ class InferenceManager:
         torch.cuda.empty_cache()
         print("Memory released and cache cleared.")
 
+
 class ModelManager:
-    def __init__(self,
+    def __init__(self, model_directory):
         """
-        Initialize the ModelManager by
+        Initialize the ModelManager by scanning all `.model.json` files in the given directory.
 
-        :param
+        :param model_directory: The directory to scan for model config files (e.g., "/path/to/models").
         """
         self.models = {}
-        self.
+        self.model_directory = model_directory
+        self.load_models()
 
-    def load_models(self
+    def load_models(self):
         """
-
+        Scan the model directory for `.model.json` files and initialize InferenceManager instances for each one.
 
-        :param
+        :param model_directory: Directory to scan for `.model.json` files.
         """
-
-
-        print(f"
+        model_files = glob.glob(os.path.join(self.model_directory, "*.model.json"))
+        if not model_files:
+            print(f"No model configuration files found in {self.model_directory}")
+            return
+
+        for file_path in model_files:
+            model_name = self.get_model_name_from_url(file_path)
+            print(f"Initializing model: {model_name} from {file_path}")
             try:
                 # Initialize InferenceManager for each model
-                self.models[model_name] = InferenceManager(config_path=
+                self.models[model_name] = InferenceManager(config_path=file_path)
             except Exception as e:
-                print(f"Failed to initialize model {model_name} from {
+                print(f"Failed to initialize model {model_name} from {file_path}: {e}")
 
     def get_model_name_from_url(self, url):
         """
-        Extract the model name from the config
+        Extract the model name from the config file path (filename without extension).
 
-        :param url: The
+        :param url: The file path of the configuration file.
         :return: The model name (file name without extension).
         """
         filename = os.path.basename(url)
@@ -331,10 +338,7 @@ class ModelManager:
                 model.release(model.base_model_pipeline)
             except Exception as e:
                 print(f"Failed to release model {model_id}: {e}")
-
-
-
-
+
 # Hugging Face file download function - returns only file path
 def download_from_hf(filename, local_dir=None):
     try:
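
For reference, a minimal sketch of the filename-to-model-name mapping the new docstrings describe, and of how the reworked manager might now be called. Everything here is an assumption drawn from the hunks above: name_from_config_path is a hypothetical stand-in for get_model_name_from_url (only its first line, filename = os.path.basename(url), is visible in the diff), the paths are made up, and note that `.model.json` is a double extension, so a single os.path.splitext() on the base name would leave a trailing `.model`.

import os

def name_from_config_path(path):
    # Hypothetical mirror of get_model_name_from_url: take the base name and
    # strip the ".model.json" suffix; how the real method strips the
    # extension is not shown in this diff.
    filename = os.path.basename(path)
    if filename.endswith(".model.json"):
        return filename[: -len(".model.json")]
    return os.path.splitext(filename)[0]

print(name_from_config_path("/path/to/models/my-model.model.json"))  # -> my-model

# Hypothetical call site after this commit: pass a local directory, not URLs.
# manager = ModelManager("/path/to/models")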