|
|
|
from abc import ABC, abstractmethod |
|
from typing import AsyncGenerator, Dict, Any, Optional, List, Tuple |
|
from dataclasses import dataclass |
|
from logging import getLogger |
|
from services.model_manager import ModelManager |
|
from services.cache import ResponseCache |
|
from services.batch_processor import BatchProcessor |
|
from services.health_check import HealthCheck |
|
|
|
from config.config import GenerationConfig, ModelConfig |
|
|
|
class BaseGenerator(ABC):
    """Base class for all generator implementations.

    Creates the shared service objects every generator holds (model manager,
    response cache, batch processor, health check) and resolves the default
    generation and model configuration. Concrete subclasses must implement
    ``generate_stream``, ``_get_generation_kwargs`` and ``generate``.
    """

    def __init__(
        self,
        model_name: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32,
    ) -> None:
        """Set up shared services and default configuration.

        Args:
            model_name: Identifier of the model this generator serves. Stored
                as ``self.model_name`` so subclasses can use it.
            device: Forwarded unchanged to ``ModelManager``.
                NOTE(review): confirm how ModelManager treats ``None``.
            default_generation_config: Stored as ``self.default_config``; a
                fresh ``GenerationConfig()`` is created when omitted or falsy.
            model_config: Stored as ``self.model_config``; a fresh
                ``ModelConfig()`` is created when omitted or falsy.
            cache_size: Passed to ``ResponseCache`` (presumably its maximum
                number of entries — confirm).
            max_batch_size: Passed to ``BatchProcessor``.
        """
        self.logger = getLogger(__name__)

        # Previously this argument was accepted and then dropped, leaving
        # subclasses with no record of which model they serve.
        self.model_name = model_name

        self.model_manager = ModelManager(device)
        self.cache = ResponseCache(cache_size)
        self.batch_processor = BatchProcessor(max_batch_size)
        self.health_check = HealthCheck()

        # `or` means any falsy config object is also replaced by a default.
        self.default_config = default_generation_config or GenerationConfig()
        self.model_config = model_config or ModelConfig()

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None,
    ) -> AsyncGenerator[str, None]:
        """Produce generated text for *prompt* as an async stream of chunks.

        Implementations are expected to be async generators (i.e. use
        ``yield``) so the return annotation holds.

        Args:
            prompt: Input text to generate from.
            config: Per-call generation settings. ``None`` presumably means
                "use ``self.default_config``" — confirm in subclasses.
        """

    @abstractmethod
    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Build a keyword-argument dict from *config*.

        Presumably these kwargs are passed to the underlying model's
        generation call — confirm in subclasses.
        """

    @abstractmethod
    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        **kwargs: Any,
    ) -> str:
        """Generate and return the full text for *prompt*.

        Args:
            prompt: Input text to generate from.
            model_kwargs: Extra model arguments (schema defined by the
                subclass).
            strategy: Name of the generation strategy; the set of valid
                names is defined by subclasses.
            **kwargs: Additional implementation-specific options.

        Returns:
            The generated text.
        """