# llama_generator.py
from config.config import GenerationConfig, ModelConfig
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
import logging
import torch
from config.config import settings
from services.prompt_builder import LlamaPromptTemplate
from services.model_manager import ModelManager
from services.base_generator import BaseGenerator
from services.strategy import DefaultStrategy, MajorityVotingStrategy, BestOfN, BeamSearch, DVT, COT, ReAct
import asyncio
from io import StringIO
import pandas as pd
from langfuse import Langfuse
from langfuse.decorators import observe, langfuse_context
import os
# Initialize Langfuse
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-04d2302a-aa5c-4870-9703-58ab64c3bcae"
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-d34ea200-feec-428e-a621-784fce93a5af"
os.environ["LANGFUSE_HOST"] = "https://chris4k-langfuse-template-space.hf.space" # 🇪🇺 EU region
try:
    langfuse = Langfuse()
except Exception as e:
    print(f"Langfuse offline: {e}")

@observe()
class LlamaGenerator(BaseGenerator):
    def __init__(
        self,
        llama_model_name: str,
        prm_model_path: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32,
    ):
        print(llama_model_name)
        print(prm_model_path)
        self.model_manager = ModelManager()
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Initialize the tokenizer for the Llama model
        self.tokenizer = self.model_manager.load_tokenizer(llama_model_name)
        super().__init__(
            llama_model_name,
            device,
            default_generation_config,
            model_config,
            cache_size,
            max_batch_size
        )
        # Initialize models
        self.model_manager.load_model(
            "llama",
            llama_model_name,
            "llama",
            self.model_config
        )
        self.model_manager.load_model(
            "prm",
            prm_model_path,
            "gguf",
            self.model_config
        )
        # Assign llama model to self.model
        self.model = self.model_manager.models.get("llama")
        if not self.model:
            raise ValueError(f"Failed to load model: {llama_model_name}")
        self.prm_model = self.model_manager.models.get("prm")
        # self.prm_tokenizer = self.model_manager.load_tokenizer(prm_model_path)
        self.prompt_builder = LlamaPromptTemplate()
        self._init_strategies()

    def _init_strategies(self):
        # Note: COT and ReAct are imported above but not registered as selectable strategies here.
        self.strategies = {
            "default": DefaultStrategy(),
            "majority_voting": MajorityVotingStrategy(),
            "best_of_n": BestOfN(),
            "beam_search": BeamSearch(),
            "dvts": DVT(),
        }

    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Get generation kwargs based on config."""
        return {
            key: getattr(config, key)
            for key in [
                "max_new_tokens",
                "temperature",
                "top_p",
                "top_k",
                "repetition_penalty",
                "length_penalty",
                "do_sample"
            ]
            if hasattr(config, key)
        }

    @observe()
    def generate_stream(self):
        return "Not implemented yet"

    @observe()
    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        **kwargs
    ) -> str:
        """
        Generate text based on a given strategy.

        Args:
            prompt (str): Input prompt for text generation.
            model_kwargs (Dict[str, Any]): Additional arguments for model generation.
            strategy (str): The generation strategy to use (default: "default").
            **kwargs: Additional arguments passed to the strategy.

        Returns:
            str: Generated text response.

        Raises:
            ValueError: If the specified strategy is not available.
        """
        # Validate that the strategy exists
        if strategy not in self.strategies:
            raise ValueError(f"Unknown strategy: {strategy}. Available strategies are: {list(self.strategies.keys())}")
        # Remove `generator` from kwargs if present, so it is not passed to the strategy twice
        kwargs.pop("generator", None)
        # Call the selected strategy with the provided arguments
        return self.strategies[strategy].generate(
            generator=self,              # The generator instance
            prompt=prompt,               # The input prompt
            model_kwargs=model_kwargs,   # Arguments for the model
            **kwargs                     # Any additional strategy-specific arguments
        )
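
    # Sketch of a strategy-specific call (illustrative placeholders only; whether a strategy
    # consumes extra kwargs such as num_samples depends on that strategy's implementation):
    #
    #     answer = generator.generate(
    #         prompt="Solve: 17 * 23 = ?",
    #         model_kwargs={"max_new_tokens": 64},
    #         strategy="best_of_n",
    #         num_samples=4,
    #     )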

    @observe()
    def generate_with_context(
        self,
        context: str,
        user_input: str,
        chat_history: List[Tuple[str, str]],
        model_kwargs: Dict[str, Any],
        max_history_turns: int = 3,
        strategy: str = "default",
        num_samples: int = 5,
        depth: int = 3,
        breadth: int = 2,
    ) -> str:
        """Generate a response using context and chat history.

        Args:
            context (str): Context for the conversation
            user_input (str): Current user input
            chat_history (List[Tuple[str, str]]): List of (user, assistant) message pairs
            model_kwargs (dict): Additional arguments for model.generate()
            max_history_turns (int): Maximum number of history turns to include
            strategy (str): Generation strategy
            num_samples (int): Number of samples for applicable strategies
            depth (int): Depth for DVTS strategy
            breadth (int): Breadth for DVTS strategy

        Returns:
            str: Generated response
        """
        prompt = self.prompt_builder.format(
            context,
            user_input,
            chat_history,
            max_history_turns
        )
        return self.generate(
            generator=self,
            prompt=prompt,
            model_kwargs=model_kwargs,
            strategy=strategy,
            num_samples=num_samples,
            depth=depth,
            breadth=breadth
        )
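
    # Context-aware call sketch: chat_history is a list of (user, assistant) tuples,
    # oldest turn first. The values below are illustrative placeholders.
    #
    #     reply = generator.generate_with_context(
    #         context="You are a concise support assistant.",
    #         user_input="Can you repeat the order number?",
    #         chat_history=[("Where is my order?", "It shipped yesterday, order #1234.")],
    #         model_kwargs={"max_new_tokens": 96},
    #     )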

    def check_health(self):  # -> HealthStatus
        """Check the health status of the generator."""
        # TODO: return self.health_check.check_system_resources() and add model status
        return "All good? - Check not implemented"