# %%
import json
import os
from typing import Any, Optional
import cohere
import numpy as np
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_core.language_models import BaseChatModel
from langchain_openai import ChatOpenAI
from loguru import logger
from openai import OpenAI
from pydantic import BaseModel, Field
from rich import print as rprint
from .configs import AVAILABLE_MODELS
from .llmcache import LLMCache
# Initialize global cache
llm_cache = LLMCache(cache_dir=".", hf_repo="umdclip/advcal-llm-cache")
def _openai_is_json_mode_supported(model_name: str) -> bool:
if model_name.startswith("gpt-4"):
return True
if model_name.startswith("gpt-3.5"):
return False
logger.warning(f"OpenAI model {model_name} is not available in this app, skipping JSON mode, returning False")
return False
class LLMOutput(BaseModel):
content: str = Field(description="The content of the response")
logprob: Optional[float] = Field(None, description="The log probability of the response")
def _get_langchain_chat_output(llm: BaseChatModel, system: str, prompt: str) -> dict[str, Any]:
output = llm.invoke([("system", system), ("human", prompt)])
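    # Chains built with .with_structured_output(..., include_raw=True) return a dict holding the
    # raw AIMessage under "raw" and the parsed Pydantic object under "parsed"; keep the raw
    # content (plus any tool calls) serialized as JSON alongside the parsed output.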
ai_message = output["raw"]
content = {"content": ai_message.content, "tool_calls": ai_message.tool_calls}
content_str = json.dumps(content)
return {"content": content_str, "output": output["parsed"].model_dump()}
def _cohere_completion(
    model: str, system: str, prompt: str, response_model: type[BaseModel], temperature: float | None = None, logprobs: bool = True
) -> dict[str, Any]:
messages = [
{"role": "system", "content": system},
{"role": "user", "content": prompt},
]
client = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
response = client.chat(
model=model,
messages=messages,
response_format={"type": "json_schema", "json_schema": response_model.model_json_schema()},
logprobs=logprobs,
temperature=temperature,
)
output = {}
output["content"] = response.message.content[0].text
output["output"] = response_model.model_validate_json(response.message.content[0].text).model_dump()
if logprobs:
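        # Combine the log probabilities reported for the generated tokens into a single
        # sequence-level score; prob is the corresponding overall probability.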
output["logprob"] = sum(lp.logprobs[0] for lp in response.logprobs)
output["prob"] = np.exp(output["logprob"])
return output
def _openai_langchain_completion(
    model: str, system: str, prompt: str, response_model: type[BaseModel], temperature: float | None = None
) -> dict[str, Any]:
llm = ChatOpenAI(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
return _get_langchain_chat_output(llm, system, prompt)
def _openai_completion(
    model: str, system: str, prompt: str, response_model: type[BaseModel], temperature: float | None = None, logprobs: bool = True
) -> dict[str, Any]:
messages = [
{"role": "system", "content": system},
{"role": "user", "content": prompt},
]
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
response = client.beta.chat.completions.parse(
model=model,
messages=messages,
response_format=response_model,
logprobs=logprobs,
temperature=temperature,
)
output = {}
output["content"] = response.choices[0].message.content
output["output"] = response.choices[0].message.parsed.model_dump()
if logprobs:
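        # Sum the per-token log probabilities of the generated content to get a
        # sequence-level log probability; prob = exp(logprob).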
output["logprob"] = sum(lp.logprob for lp in response.choices[0].logprobs.content)
output["prob"] = np.exp(output["logprob"])
return output
def _anthropic_completion(
    model: str, system: str, prompt: str, response_model: type[BaseModel], temperature: float | None = None
) -> dict[str, Any]:
llm = ChatAnthropic(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
return _get_langchain_chat_output(llm, system, prompt)
def _llm_completion(
    model: str, system: str, prompt: str, response_format: type[BaseModel], temperature: float | None = None, logprobs: bool = False
) -> dict[str, Any]:
"""
Generate a completion from an LLM provider with structured output without caching.
Args:
model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
system (str): System prompt/instructions for the model
prompt (str): User prompt/input
        response_format: Pydantic model defining the expected response structure
        temperature (float, optional): Sampling temperature passed to the provider. Defaults to None.
        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
            Note: Not supported by Anthropic models.
    Returns:
        dict: Contains:
            - content: Raw text content of the model response
            - output: The structured response matching response_format
            - logprob: (optional) Sum of log probabilities if logprobs=True
            - prob: (optional) Exponential of logprob if logprobs=True
    Raises:
        ValueError: If the provider is unsupported, or if logprobs=True with a model that
            does not support it (Anthropic, or OpenAI models without JSON mode)
"""
model_name = AVAILABLE_MODELS[model]["model"]
provider = model.split("/")[0]
if provider == "Cohere":
return _cohere_completion(model_name, system, prompt, response_format, temperature, logprobs)
elif provider == "OpenAI":
if _openai_is_json_mode_supported(model_name):
return _openai_completion(model_name, system, prompt, response_format, temperature, logprobs)
elif logprobs:
raise ValueError(f"{model} does not support logprobs feature.")
else:
return _openai_langchain_completion(model_name, system, prompt, response_format, temperature)
elif provider == "Anthropic":
if logprobs:
raise ValueError("Anthropic models do not support logprobs")
return _anthropic_completion(model_name, system, prompt, response_format, temperature)
else:
raise ValueError(f"Provider {provider} not supported")
def completion(
    model: str, system: str, prompt: str, response_format: type[BaseModel], temperature: float | None = None, logprobs: bool = False
) -> dict[str, Any]:
"""
Generate a completion from an LLM provider with structured output with caching.
Args:
model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
system (str): System prompt/instructions for the model
prompt (str): User prompt/input
        response_format: Pydantic model defining the expected response structure
        temperature (float, optional): Sampling temperature passed to the provider. Defaults to None.
        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
            Note: Not supported by Anthropic models; automatically disabled (with a warning)
            for models whose AVAILABLE_MODELS entry does not advertise logprobs support.
    Returns:
        dict: Contains:
            - content: Raw text content of the model response
            - output: The structured response matching response_format
            - logprob: (optional) Sum of log probabilities if logprobs=True
            - prob: (optional) Exponential of logprob if logprobs=True
    Raises:
        ValueError: If model is not listed in AVAILABLE_MODELS, or if logprobs=True reaches a
            provider that does not support it
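    Example (illustrative; assumes an "OpenAI/gpt-4o-mini" entry in AVAILABLE_MODELS and a
    configured OPENAI_API_KEY):
        class Answer(BaseModel):
            answer: str = Field(description="Short answer to the question")
        result = completion("OpenAI/gpt-4o-mini", "Be terse.", "What is 2 + 2?", Answer, logprobs=True)
        # result["output"] is the parsed dict; "logprob"/"prob" are included when supported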
"""
if model not in AVAILABLE_MODELS:
raise ValueError(f"Model {model} not supported")
if logprobs and not AVAILABLE_MODELS[model].get("logprobs", False):
logger.warning(f"{model} does not support logprobs feature, setting logprobs to False")
logprobs = False
# Check cache first
cached_response = llm_cache.get(model, system, prompt, response_format, temperature)
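    # Only treat the cached entry as a hit if it can satisfy the request: when logprobs are
    # requested, the cached response must already carry a logprob; otherwise fall through
    # and re-query the provider.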
    if cached_response and (not logprobs or cached_response.get("logprob") is not None):
logger.info(f"Cache hit for model {model}")
return cached_response
logger.info(f"Cache miss for model {model}, calling API. Logprobs: {logprobs}")
# Continue with the original implementation for cache miss
response = _llm_completion(model, system, prompt, response_format, temperature, logprobs)
# Update cache with the new response
llm_cache.set(
model,
system,
prompt,
response_format,
temperature,
response,
)
return response
# %%
if __name__ == "__main__":
from tqdm import tqdm
class ExplainedAnswer(BaseModel):
"""
The answer to the question and a terse explanation of the answer.
"""
answer: str = Field(description="The short answer to the question")
        explanation: str = Field(description="A terse explanation of the answer, at most 5 words.")
models = list(AVAILABLE_MODELS.keys())[:1] # Just use the first model for testing
system = "You are an accurate and concise explainer of scientific concepts."
prompt = "Which planet is closest to the sun in the Milky Way galaxy? Answer directly, no explanation needed."
llm_cache = LLMCache(cache_dir=".", hf_repo="umdclip/advcal-llm-cache", reset=True)
# First call - should be a cache miss
logger.info("First call - should be a cache miss")
for model in tqdm(models):
response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
rprint(response)
# Second call - should be a cache hit
logger.info("Second call - should be a cache hit")
for model in tqdm(models):
response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
rprint(response)
# Slightly different prompt - should be a cache miss
logger.info("Different prompt - should be a cache miss")
prompt2 = "Which planet is closest to the sun? Answer directly."
for model in tqdm(models):
response = completion(model, system, prompt2, ExplainedAnswer, logprobs=False)
rprint(response)
# Get cache entries count from SQLite
try:
cache_entries = llm_cache.get_all_entries()
logger.info(f"Cache now has {len(cache_entries)} items")
except Exception as e:
logger.error(f"Failed to get cache entries: {e}")
# Test adding entry with temperature parameter
logger.info("Testing with temperature parameter")
response = completion(models[0], system, "What is Mars?", ExplainedAnswer, temperature=0.7, logprobs=False)
rprint(response)
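    # Illustrative extra check: exercise the logprobs path, but only when the first model's
    # AVAILABLE_MODELS entry advertises logprobs support (Anthropic models never do).
    if AVAILABLE_MODELS[models[0]].get("logprobs", False):
        logger.info("Testing with logprobs enabled")
        response = completion(models[0], system, "What is Venus?", ExplainedAnswer, logprobs=True)
        rprint(response)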
# Demonstrate forced sync to HF if repo is configured
if llm_cache.hf_repo_id:
logger.info("Forcing sync to HF dataset")
try:
llm_cache.sync_to_hf()
logger.info("Successfully synced to HF dataset")
except Exception as e:
logger.exception(f"Failed to sync to HF: {e}")
else:
logger.info("HF repo not configured, skipping sync test")
# %%