# %%
import json
import os
from typing import Any, Optional
import cohere
import numpy as np
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_core.language_models import BaseChatModel
from langchain_openai import ChatOpenAI
from loguru import logger
from openai import OpenAI
from pydantic import BaseModel, Field
from rich import print as rprint
from .configs import AVAILABLE_MODELS
from .llmcache import LLMCache
# Initialize global cache
llm_cache = LLMCache(cache_dir=".", hf_repo="umdclip/advcal-llm-cache")


def _openai_is_json_mode_supported(model_name: str) -> bool:
    if model_name.startswith("gpt-4"):
        return True
    if model_name.startswith("gpt-3.5"):
        return False
    logger.warning(f"OpenAI model {model_name} is not available in this app, skipping JSON mode, returning False")
    return False
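
# Illustrative behavior, assuming only GPT-4-family model names get native JSON mode here
# (the model names below are examples, not necessarily entries in AVAILABLE_MODELS):
#   _openai_is_json_mode_supported("gpt-4o")         # True
#   _openai_is_json_mode_supported("gpt-3.5-turbo")  # False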


class LLMOutput(BaseModel):
    content: str = Field(description="The content of the response")
    logprob: Optional[float] = Field(None, description="The log probability of the response")


def _get_langchain_chat_output(llm: BaseChatModel, system: str, prompt: str) -> dict[str, Any]:
    # with_structured_output(..., include_raw=True) returns a dict with "raw" (AIMessage) and "parsed" keys
    output = llm.invoke([("system", system), ("human", prompt)])
    ai_message = output["raw"]
    content = {"content": ai_message.content, "tool_calls": ai_message.tool_calls}
    content_str = json.dumps(content)
    return {"content": content_str, "output": output["parsed"].model_dump()}


def _cohere_completion(
    model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
) -> dict[str, Any]:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    client = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
    response = client.chat(
        model=model,
        messages=messages,
        response_format={"type": "json_schema", "json_schema": response_model.model_json_schema()},
        logprobs=logprobs,
        temperature=temperature,
    )
    output = {}
    output["content"] = response.message.content[0].text
    output["output"] = response_model.model_validate_json(response.message.content[0].text).model_dump()
    if logprobs:
        output["logprob"] = sum(lp.logprobs[0] for lp in response.logprobs)
        output["prob"] = np.exp(output["logprob"])
    return output
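
# Note on the logprob aggregation above: the sequence log-probability is the sum of the
# per-token logprobs, and "prob" is its exponential. For example, token logprobs of
# -0.1 and -0.2 sum to -0.3, giving a sequence probability of exp(-0.3) ≈ 0.74.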


def _openai_langchain_completion(
    model: str, system: str, prompt: str, response_model, temperature: float | None = None
) -> dict[str, Any]:
    llm = ChatOpenAI(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
    return _get_langchain_chat_output(llm, system, prompt)


def _openai_completion(
    model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
) -> dict[str, Any]:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    response = client.beta.chat.completions.parse(
        model=model,
        messages=messages,
        response_format=response_model,
        logprobs=logprobs,
        temperature=temperature,
    )
    output = {}
    output["content"] = response.choices[0].message.content
    output["output"] = response.choices[0].message.parsed.model_dump()
    if logprobs:
        output["logprob"] = sum(lp.logprob for lp in response.choices[0].logprobs.content)
        output["prob"] = np.exp(output["logprob"])
    return output


def _anthropic_completion(
    model: str, system: str, prompt: str, response_model, temperature: float | None = None
) -> dict[str, Any]:
    llm = ChatAnthropic(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
    return _get_langchain_chat_output(llm, system, prompt)


def _llm_completion(
    model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
) -> dict[str, Any]:
    """
    Generate a completion from an LLM provider with structured output, without caching.

    Args:
        model (str): Provider and model name in the format "provider/model" (e.g. "OpenAI/gpt-4")
        system (str): System prompt/instructions for the model
        prompt (str): User prompt/input
        response_format: Pydantic model defining the expected response structure
        temperature (float, optional): Sampling temperature passed through to the provider
        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
            Note: Not supported by Anthropic models.

    Returns:
        dict: Contains:
            - content: The raw text content of the response
            - output: The structured response matching response_format
            - logprob: (optional) Sum of log probabilities if logprobs=True
            - prob: (optional) Exponential of logprob if logprobs=True

    Raises:
        ValueError: If logprobs=True for a model that does not support it, or if the provider is not supported
    """
    model_name = AVAILABLE_MODELS[model]["model"]
    provider = model.split("/")[0]
    if provider == "Cohere":
        return _cohere_completion(model_name, system, prompt, response_format, temperature, logprobs)
    elif provider == "OpenAI":
        if _openai_is_json_mode_supported(model_name):
            return _openai_completion(model_name, system, prompt, response_format, temperature, logprobs)
        elif logprobs:
            raise ValueError(f"{model} does not support logprobs feature.")
        else:
            return _openai_langchain_completion(model_name, system, prompt, response_format, temperature)
    elif provider == "Anthropic":
        if logprobs:
            raise ValueError("Anthropic models do not support logprobs")
        return _anthropic_completion(model_name, system, prompt, response_format, temperature)
    else:
        raise ValueError(f"Provider {provider} not supported")


def completion(
    model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
) -> dict[str, Any]:
    """
    Generate a completion from an LLM provider with structured output, with caching.

    Args:
        model (str): Provider and model name in the format "provider/model" (e.g. "OpenAI/gpt-4")
        system (str): System prompt/instructions for the model
        prompt (str): User prompt/input
        response_format: Pydantic model defining the expected response structure
        temperature (float, optional): Sampling temperature passed through to the provider
        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
            Note: Not supported by Anthropic models; silently downgraded with a warning.

    Returns:
        dict: Contains:
            - content: The raw text content of the response
            - output: The structured response matching response_format
            - logprob: (optional) Sum of log probabilities if logprobs=True
            - prob: (optional) Exponential of logprob if logprobs=True

    Raises:
        ValueError: If the model is not listed in AVAILABLE_MODELS
    """
    if model not in AVAILABLE_MODELS:
        raise ValueError(f"Model {model} not supported")
    if logprobs and not AVAILABLE_MODELS[model].get("logprobs", False):
        logger.warning(f"{model} does not support logprobs feature, setting logprobs to False")
        logprobs = False

    # Check cache first
    cached_response = llm_cache.get(model, system, prompt, response_format, temperature)
    if cached_response and (not logprobs or cached_response.get("logprob")):
        logger.info(f"Cache hit for model {model}")
        return cached_response

    logger.info(f"Cache miss for model {model}, calling API. Logprobs: {logprobs}")

    # Continue with the original implementation for cache miss
    response = _llm_completion(model, system, prompt, response_format, temperature, logprobs)

    # Update cache with the new response
    llm_cache.set(
        model,
        system,
        prompt,
        response_format,
        temperature,
        response,
    )
    return response
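
# Minimal usage sketch (the "OpenAI/gpt-4o-mini" key is an assumption; use any key that
# actually appears in AVAILABLE_MODELS):
#
#   class Answer(BaseModel):
#       answer: str = Field(description="The short answer")
#
#   result = completion("OpenAI/gpt-4o-mini", "Answer tersely.", "Capital of France?", Answer)
#   result["output"]["answer"]  # structured field parsed into the Pydantic schema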
# %%
if __name__ == "__main__":
from tqdm import tqdm
class ExplainedAnswer(BaseModel):
"""
The answer to the question and a terse explanation of the answer.
"""
answer: str = Field(description="The short answer to the question")
explanation: str = Field(description="5 words terse best explanation of the answer.")
models = list(AVAILABLE_MODELS.keys())[:1] # Just use the first model for testing
system = "You are an accurate and concise explainer of scientific concepts."
prompt = "Which planet is closest to the sun in the Milky Way galaxy? Answer directly, no explanation needed."
llm_cache = LLMCache(cache_dir=".", hf_repo="umdclip/advcal-llm-cache", reset=True)
# First call - should be a cache miss
logger.info("First call - should be a cache miss")
for model in tqdm(models):
response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
rprint(response)
# Second call - should be a cache hit
logger.info("Second call - should be a cache hit")
for model in tqdm(models):
response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
rprint(response)
# Slightly different prompt - should be a cache miss
logger.info("Different prompt - should be a cache miss")
prompt2 = "Which planet is closest to the sun? Answer directly."
for model in tqdm(models):
response = completion(model, system, prompt2, ExplainedAnswer, logprobs=False)
rprint(response)
# Get cache entries count from SQLite
try:
cache_entries = llm_cache.get_all_entries()
logger.info(f"Cache now has {len(cache_entries)} items")
except Exception as e:
logger.error(f"Failed to get cache entries: {e}")
# Test adding entry with temperature parameter
logger.info("Testing with temperature parameter")
response = completion(models[0], system, "What is Mars?", ExplainedAnswer, temperature=0.7, logprobs=False)
rprint(response)
# Demonstrate forced sync to HF if repo is configured
if llm_cache.hf_repo_id:
logger.info("Forcing sync to HF dataset")
try:
llm_cache.sync_to_hf()
logger.info("Successfully synced to HF dataset")
except Exception as e:
logger.exception(f"Failed to sync to HF: {e}")
else:
logger.info("HF repo not configured, skipping sync test")
# %%