Maharshi Gor
Enhance model provider detection and add a repository management script. Add support for multi-step agents.
973519b
# Description: Utility functions for the model_step component.
from envs import AVAILABLE_MODELS, UNSELECTED_MODEL_NAME | |
def guess_model_provider(model_name: str):
    """Guess the provider of a model from its name.

    Matching is case-insensitive and relies on simple name heuristics:
    names starting with ``gpt-`` map to OpenAI; names containing ``sonnet``,
    ``claude`` or ``haiku`` map to Anthropic; names containing ``command``
    map to Cohere.

    Args:
        model_name: Model identifier to classify.

    Returns:
        The provider name: ``"OpenAI"``, ``"Anthropic"`` or ``"Cohere"``.

    Raises:
        ValueError: If none of the heuristics matches the name.
    """
    # Match against a lowercased copy so the error below can still echo the
    # caller's original spelling instead of a case-folded variant.
    normalized = model_name.lower()
    if normalized.startswith("gpt-"):
        return "OpenAI"
    if any(marker in normalized for marker in ("sonnet", "claude", "haiku")):
        return "Anthropic"
    if "command" in normalized:
        return "Cohere"
    raise ValueError(f"Model `{model_name}` not yet supported")
def get_model_and_provider(model_name: str):
    """Split a model identifier into its full model name and provider.

    Accepts either a bare name (``"<model>"``) or a provider-qualified one
    (``"<provider>/<model>"``). Only the first ``/`` separates the provider;
    any later slashes remain part of the model name. The model part is looked
    up in ``AVAILABLE_MODELS``, falling back to the name itself when there is
    no entry.

    Args:
        model_name: Bare or provider-qualified model identifier, or
            ``UNSELECTED_MODEL_NAME``.

    Returns:
        A ``(full_model_name, provider)`` tuple, or ``("", "")`` when
        ``model_name`` equals ``UNSELECTED_MODEL_NAME``.

    Raises:
        ValueError: Propagated from ``guess_model_provider`` when a bare name
            matches no known provider.
    """
    if model_name == UNSELECTED_MODEL_NAME:
        return "", ""

    # partition() always yields exactly three parts, so there are only two
    # cases to handle: separator found or not.
    provider, separator, short_name = model_name.partition("/")
    if not separator:
        # No provider prefix: resolve the name first, then infer the provider
        # from the resolved name.
        full_model_name = AVAILABLE_MODELS.get(model_name, model_name)
        return full_model_name, guess_model_provider(full_model_name)

    # Explicit provider prefix: trust it as given and resolve only the model.
    return AVAILABLE_MODELS.get(short_name, short_name), provider
def get_full_model_name(model_name: str, provider: str = ""):
    """Build the ``"<provider>/<model>"`` identifier for a model.

    Args:
        model_name: Model name; an empty string means no model is selected.
        provider: Provider name. When falsy, it is inferred from
            ``model_name`` via ``guess_model_provider``.

    Returns:
        ``UNSELECTED_MODEL_NAME`` when ``model_name`` is empty, otherwise the
        provider and model name joined by ``/``.
    """
    if model_name == "":
        return UNSELECTED_MODEL_NAME
    # Only fall back to guessing when the caller did not name a provider.
    resolved_provider = provider or guess_model_provider(model_name)
    return f"{resolved_provider}/{model_name}"