# llm/services/base_generator.py
# Chris4K — "Update services/base_generator.py" (commit 54329ad, verified)
# base_generator.py
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Dict, Any, Optional, List, Tuple
from dataclasses import dataclass
from logging import getLogger
from services.model_manager import ModelManager
from services.cache import ResponseCache
from services.batch_processor import BatchProcessor
from services.health_check import HealthCheck
from config.config import GenerationConfig, ModelConfig
class BaseGenerator(ABC):
    """Base class for all generator implementations.

    Constructs the shared services used by generators (model manager,
    response cache, batch processor, health check) and declares the
    interface subclasses must implement: ``generate_stream``, ``generate``
    and ``_get_generation_kwargs``.

    Args:
        model_name: Identifier of the model this generator is built for.
            Stored as ``self.model_name`` so subclasses can look up the
            model/tokenizer they need.
        device: Forwarded as-is to ``ModelManager``.
        default_generation_config: Stored as ``self.default_config``; a new
            ``GenerationConfig()`` is created when it is ``None``.
        model_config: Stored as ``self.model_config``; a new
            ``ModelConfig()`` is created when it is ``None``.
        cache_size: Forwarded to ``ResponseCache``.
        max_batch_size: Forwarded to ``BatchProcessor``.
    """

    def __init__(
        self,
        model_name: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32
    ):
        self.logger = getLogger(__name__)
        # Keep the model identifier so subclasses can resolve the resources
        # for this model (e.g. a tokenizer lookup keyed by model name via
        # self.model_manager) instead of having to re-pass it.
        self.model_name = model_name
        self.model_manager = ModelManager(device)
        self.cache = ResponseCache(cache_size)
        self.batch_processor = BatchProcessor(max_batch_size)
        self.health_check = HealthCheck()
        # NOTE: no tokenizer is set up here; subclasses that need one must
        # load it themselves.
        # `is None` rather than `or`: an explicitly supplied config is never
        # swapped for a default just because it evaluates as falsy.
        self.default_config = (
            default_generation_config
            if default_generation_config is not None
            else GenerationConfig()
        )
        self.model_config = (
            model_config if model_config is not None else ModelConfig()
        )

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None
    ) -> AsyncGenerator[str, None]:
        """Stream generated text for *prompt* as an async generator of ``str``.

        Implementations are expected to be async generators (``async def``
        containing ``yield``).

        Args:
            prompt: Input text.
            config: Per-call generation settings. ``None`` presumably means
                "use ``self.default_config``"; the subclass decides.

        Yields:
            Pieces of generated text.
        """

    @abstractmethod
    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Return the keyword-argument dict derived from *config*.

        The keys and values are implementation-specific (presumably the
        arguments handed to the underlying model's generation call).
        """

    @abstractmethod
    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        **kwargs
    ) -> str:
        """Generate text for *prompt* synchronously and return it as one string.

        Args:
            prompt: Input text.
            model_kwargs: Keyword arguments for the model call; their
                contents are defined by the subclass.
            strategy: Name of the generation strategy to use. The set of
                valid names is up to the subclass; ``"default"`` if omitted.
            **kwargs: Additional implementation-specific options.

        Returns:
            The generated text.
        """