RWKV-LatestSpace / config.py
sparkleman
UPDATE: change default model load workflow
50f89e3
raw
history blame
3.25 kB
from pydantic import BaseModel, Field
from typing import List, Optional
from typing import List, Optional, Union, Any
import sys
from pydantic_settings import BaseSettings
# Command-line settings for the service. With ``cli_parse_args=True`` the
# values are populated from the process arguments when the class is
# instantiated.
# NOTE(review): documented with comments rather than a class docstring on
# purpose -- with ``cli_use_class_docs_for_groups=True`` a docstring may be
# surfaced in the generated CLI help text; confirm before adding one.
class CliConfig(
    BaseSettings,
    cli_parse_args=True,
    cli_use_class_docs_for_groups=True,
):
    # Path of the configuration file to load.
    CONFIG_FILE: str = Field(
        default="./config.local.yaml",
        description="Config file path",
    )


# Module-level instance, created as soon as this module is imported.
CLI_CONFIG = CliConfig()
class SamplerConfig(BaseModel):
    """Default sampler configuration for each model."""

    # Generation length limit.
    max_tokens: int = Field(
        default=512,
        description="Maximum number of tokens to generate.",
    )

    # Sampling controls.
    temperature: float = Field(
        default=1.0,
        description="Sampling temperature.",
    )
    top_p: float = Field(
        default=0.3,
        description="Top-p sampling threshold.",
    )

    # Penalty settings.
    presence_penalty: float = Field(
        default=0.5,
        description="Presence penalty.",
    )
    count_penalty: float = Field(
        default=0.5,
        description="Count penalty.",
    )
    penalty_decay: float = Field(
        default=0.996,
        description="Penalty decay factor.",
    )

    # Stop conditions, expressed both as text sequences and as token ids.
    stop: List[str] = Field(
        default=["\n\n"],
        description="List of stop sequences.",
    )
    stop_tokens: List[int] = Field(
        default=[0],
        description="List of stop tokens.",
    )
class ModelConfig(BaseModel):
    """Configuration for each individual model."""

    # The only required field: every model entry must name its service.
    SERVICE_NAME: str = Field(..., description="Service name of the model.")

    # Model location: a local path and/or download coordinates.
    # All of these default to None.
    MODEL_FILE_PATH: Optional[str] = Field(
        default=None,
        description="Model file path.",
    )
    DOWNLOAD_MODEL_FILE_NAME: Optional[str] = Field(
        default=None,
        description="Model name, should end with .pth",
    )
    DOWNLOAD_MODEL_REPO_ID: Optional[str] = Field(
        default=None,
        description="Model repository ID on Hugging Face Hub.",
    )
    DOWNLOAD_MODEL_DIR: Optional[str] = Field(
        default=None,
        description="Directory to download the model to.",
    )

    # Capability and default-selection flags; all off unless set.
    REASONING: bool = Field(
        default=False,
        description="Whether reasoning is enabled for this model.",
    )
    DEFAULT_CHAT: bool = Field(
        default=False,
        description="Whether this model is the default chat model.",
    )
    DEFAULT_REASONING: bool = Field(
        default=False,
        description="Whether this model is the default reasoning model.",
    )

    # Sampler settings used when none are given for this model; the default
    # is a SamplerConfig built entirely from its own field defaults.
    DEFAULT_SAMPLER: SamplerConfig = Field(
        default=SamplerConfig(),
        description="Default sampler configuration for this model.",
    )

    VOCAB: str = Field(
        default="rwkv_vocab_v20230424",
        description="Vocab Name",
    )
class RootConfig(BaseModel):
    """Root configuration for the RWKV service."""

    # HOST and PORT are declared Optional because they are commented out in
    # the example YAML.
    HOST: Optional[str] = Field(
        default="127.0.0.1",
        description="Host IP address to bind to.",
    )
    PORT: Optional[int] = Field(
        default=8000,
        description="Port number to listen on.",
    )

    # Execution settings.
    STRATEGY: str = Field(
        default="cpu",
        description="Strategy for model execution (e.g., 'cuda fp16').",
    )
    RWKV_CUDA_ON: bool = Field(
        default=False,
        description="Whether to enable RWKV CUDA kernel.",
    )
    CHUNK_LEN: int = Field(
        default=256,
        description="Chunk length for processing.",
    )

    # Required: validation fails if the config omits the model list.
    MODELS: List[ModelConfig] = Field(
        ...,
        description="List of model configurations.",
    )
import yaml

# Load the YAML file named by CLI_CONFIG.CONFIG_FILE and validate it against
# RootConfig at import time, exposing the result as the module-level CONFIG.
#
# Any failure (the file cannot be opened, the YAML cannot be parsed, or the
# data does not validate) is reported and terminates the process. The broad
# ``except Exception`` is intentional at this top-level boundary, so the
# message below is printed for I/O and YAML errors as well as validation ones.
try:
    with open(CLI_CONFIG.CONFIG_FILE, "r", encoding="utf-8") as f:
        CONFIG = RootConfig.model_validate(yaml.safe_load(f.read()))
except Exception as e:
    # Diagnostics go to stderr, and the exit status is non-zero so that
    # shells and process supervisors see the startup failure (the status was
    # previously 0, which signals success).
    print(f"Pydantic Model Validation Failed: {e}", file=sys.stderr)
    sys.exit(1)