import asyncio
import logging
import os
from datetime import datetime
from io import StringIO
from typing import List, Dict, Any, Optional, Tuple

import pandas as pd
import torch
from langfuse import Langfuse
from langfuse.decorators import observe, langfuse_context

from config.config import GenerationConfig, ModelConfig, settings
from services.base_generator import BaseGenerator
from services.model_manager import ModelManager
from services.prompt_builder import LlamaPromptTemplate
from services.strategy import (
    DefaultStrategy,
    MajorityVotingStrategy,
    BestOfN,
    BeamSearch,
    DVT,
    COT,
    ReAct,
)

os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-04d2302a-aa5c-4870-9703-58ab64c3bcae" |
|
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-d34ea200-feec-428e-a621-784fce93a5af" |
|
os.environ["LANGFUSE_HOST"] = "https://chris4k-langfuse-template-space.hf.space" |
|
|
|
try: |
|
langfuse = Langfuse() |
|
except Exception as e: |
|
print("Langfuse Offline") |
|
|
|
|
|
|
|
@observe()
class LlamaGenerator(BaseGenerator):
    def __init__(
        self,
        llama_model_name: str,
        prm_model_path: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32,
    ):
        print(llama_model_name)
        print(prm_model_path)

        self.model_manager = ModelManager()

        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = self.model_manager.load_tokenizer(llama_model_name)

        super().__init__(
            llama_model_name,
            device,
            default_generation_config,
            model_config,
            cache_size,
            max_batch_size
        )

        # Load the main Llama model and the PRM model.
        self.model_manager.load_model(
            "llama",
            llama_model_name,
            "llama",
            self.model_config
        )
        self.model_manager.load_model(
            "prm",
            prm_model_path,
            "gguf",
            self.model_config
        )

        self.model = self.model_manager.models.get("llama")
        if not self.model:
            raise ValueError(f"Failed to load model: {llama_model_name}")

        self.prm_model = self.model_manager.models.get("prm")

        self.prompt_builder = LlamaPromptTemplate()
        self._init_strategies()

    def _init_strategies(self):
        self.strategies = {
            "default": DefaultStrategy(),
            "majority_voting": MajorityVotingStrategy(),
            "best_of_n": BestOfN(),
            "beam_search": BeamSearch(),
            "dvts": DVT(),
        }

    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Get generation kwargs based on config."""
        return {
            key: getattr(config, key)
            for key in [
                "max_new_tokens",
                "temperature",
                "top_p",
                "top_k",
                "repetition_penalty",
                "length_penalty",
                "do_sample",
            ]
            if hasattr(config, key)
        }

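    # Illustrative sketch of what the filter above produces, assuming GenerationConfig
    # exposes attributes with these names (the constructor call is hypothetical):
    #
    #     config = GenerationConfig(max_new_tokens=256, temperature=0.8, do_sample=True)
    #     generator._get_generation_kwargs(config)
    #     # -> {"max_new_tokens": 256, "temperature": 0.8, "do_sample": True}
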
    @observe()
    def generate_stream(self):
        raise NotImplementedError("Streaming generation is not implemented yet.")

    @observe()
    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        **kwargs
    ) -> str:
        """
        Generate text based on a given strategy.

        Args:
            prompt (str): Input prompt for text generation.
            model_kwargs (Dict[str, Any]): Additional arguments for model generation.
            strategy (str): The generation strategy to use (default: "default").
            **kwargs: Additional arguments passed to the strategy.

        Returns:
            str: Generated text response.

        Raises:
            ValueError: If the specified strategy is not available.
        """
        if strategy not in self.strategies:
            raise ValueError(
                f"Unknown strategy: {strategy}. "
                f"Available strategies are: {list(self.strategies.keys())}"
            )

        # Drop any stray "generator" kwarg so the strategy receives exactly one generator instance.
        kwargs.pop("generator", None)

        return self.strategies[strategy].generate(
            generator=self,
            prompt=prompt,
            model_kwargs=model_kwargs,
            **kwargs
        )

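    # Usage sketch (illustrative only; `generator` is an already constructed LlamaGenerator,
    # and the strategy kwargs below are assumptions based on the names registered in
    # _init_strategies, not a documented contract):
    #
    #     text = generator.generate(
    #         prompt="Summarize the theory of relativity in two sentences.",
    #         model_kwargs={"max_new_tokens": 128, "temperature": 0.7},
    #         strategy="best_of_n",
    #         num_samples=4,
    #     )
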
    @observe()
    def generate_with_context(
        self,
        context: str,
        user_input: str,
        chat_history: List[Tuple[str, str]],
        model_kwargs: Dict[str, Any],
        max_history_turns: int = 3,
        strategy: str = "default",
        num_samples: int = 5,
        depth: int = 3,
        breadth: int = 2,
    ) -> str:
        """Generate a response using context and chat history.

        Args:
            context (str): Context for the conversation
            user_input (str): Current user input
            chat_history (List[Tuple[str, str]]): List of (user, assistant) message pairs
            model_kwargs (dict): Additional arguments for model.generate()
            max_history_turns (int): Maximum number of history turns to include
            strategy (str): Generation strategy
            num_samples (int): Number of samples for applicable strategies
            depth (int): Depth for DVTS strategy
            breadth (int): Breadth for DVTS strategy

        Returns:
            str: Generated response
        """
        prompt = self.prompt_builder.format(
            context,
            user_input,
            chat_history,
            max_history_turns
        )
        return self.generate(
            prompt=prompt,
            model_kwargs=model_kwargs,
            strategy=strategy,
            num_samples=num_samples,
            depth=depth,
            breadth=breadth
        )

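    # Usage sketch (illustrative; the context, input, and history values are made up):
    #
    #     reply = generator.generate_with_context(
    #         context="You are a helpful assistant for a cooking website.",
    #         user_input="How long should I knead the dough?",
    #         chat_history=[("Hi", "Hello! How can I help you today?")],
    #         model_kwargs={"max_new_tokens": 256},
    #     )
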
    def check_health(self) -> str:
        """Check the health status of the generator."""
        # TODO: implement a real health check.
        return "Health check not implemented yet."