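"""Structured-output LLM completion helpers with on-disk caching.

Wraps the Cohere, OpenAI, and Anthropic chat APIs behind a single ``completion()``
entry point that validates responses against a Pydantic model, optionally returns
token log probabilities, and caches results through ``LLMCache``.

Usage (a minimal sketch; assumes "OpenAI/gpt-4" is a key in AVAILABLE_MODELS):

    class Answer(BaseModel):
        answer: str = Field(description="The short answer")

    result = completion("OpenAI/gpt-4", "Be terse.", "What is 2 + 2?", Answer)
    print(result["output"]["answer"])
"""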
import json
import os
from typing import Any, Optional

import cohere
import numpy as np
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_core.language_models import BaseChatModel
from langchain_openai import ChatOpenAI
from loguru import logger
from openai import OpenAI
from pydantic import BaseModel, Field
from rich import print as rprint

from .configs import AVAILABLE_MODELS
from .llmcache import LLMCache

llm_cache = LLMCache(cache_dir=".", hf_repo="umdclip/advcal-llm-cache")
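# NOTE: AVAILABLE_MODELS (from .configs) is assumed to map "Provider/model" keys
# (e.g. "OpenAI/gpt-4") to entries shaped like {"model": <provider model id>,
# "logprobs": <bool>}; only the "model" and "logprobs" fields are read below.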
def _openai_is_json_mode_supported(model_name: str) -> bool:
    # Structured JSON output is only enabled for GPT-4-family models here.
    if model_name.startswith("gpt-4"):
        return True
    if model_name.startswith("gpt-3.5"):
        return False
    logger.warning(f"OpenAI model {model_name} is not recognized by this app; assuming JSON mode is unsupported")
    return False
class LLMOutput(BaseModel):
    content: str = Field(description="The content of the response")
    logprob: Optional[float] = Field(None, description="The log probability of the response")
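# _get_langchain_chat_output expects an llm built via .with_structured_output(..., include_raw=True),
# which returns a dict holding the raw AIMessage under "raw" and the validated
# response_model instance under "parsed".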
def _get_langchain_chat_output(llm: BaseChatModel, system: str, prompt: str) -> dict[str, Any]:
    output = llm.invoke([("system", system), ("human", prompt)])
    ai_message = output["raw"]
    content = {"content": ai_message.content, "tool_calls": ai_message.tool_calls}
    content_str = json.dumps(content)
    return {"content": content_str, "output": output["parsed"].model_dump()}
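# The Cohere and OpenAI paths below report a sequence-level log probability by summing
# per-token logprobs; exp(logprob) then gives the probability of the sampled completion.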
def _cohere_completion(
    model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
) -> dict[str, Any]:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    client = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
    # Constrain the response to the JSON schema derived from response_model.
    response = client.chat(
        model=model,
        messages=messages,
        response_format={"type": "json_schema", "json_schema": response_model.model_json_schema()},
        logprobs=logprobs,
        temperature=temperature,
    )
    output = {}
    output["content"] = response.message.content[0].text
    output["output"] = response_model.model_validate_json(response.message.content[0].text).model_dump()
    if logprobs:
        output["logprob"] = sum(lp.logprobs[0] for lp in response.logprobs)
        output["prob"] = np.exp(output["logprob"])
    return output
def _openai_langchain_completion(
    model: str, system: str, prompt: str, response_model, temperature: float | None = None
) -> dict[str, Any]:
    llm = ChatOpenAI(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
    return _get_langchain_chat_output(llm, system, prompt)
def _openai_completion(
    model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
) -> dict[str, Any]:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    # Parse the response directly into response_model via the structured-outputs endpoint.
    response = client.beta.chat.completions.parse(
        model=model,
        messages=messages,
        response_format=response_model,
        logprobs=logprobs,
        temperature=temperature,
    )
    output = {}
    output["content"] = response.choices[0].message.content
    output["output"] = response.choices[0].message.parsed.model_dump()
    if logprobs:
        output["logprob"] = sum(lp.logprob for lp in response.choices[0].logprobs.content)
        output["prob"] = np.exp(output["logprob"])
    return output
def _anthropic_completion(
    model: str, system: str, prompt: str, response_model, temperature: float | None = None
) -> dict[str, Any]:
    llm = ChatAnthropic(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
    return _get_langchain_chat_output(llm, system, prompt)
def _llm_completion(
    model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
) -> dict[str, Any]:
    """
    Generate a completion from an LLM provider with structured output, without caching.

    Args:
        model (str): Provider and model name in the format "provider/model" (e.g. "OpenAI/gpt-4")
        system (str): System prompt/instructions for the model
        prompt (str): User prompt/input
        response_format: Pydantic model defining the expected response structure
        temperature (float, optional): Sampling temperature passed through to the provider
        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
            Note: Not supported by Anthropic models.

    Returns:
        dict: Contains:
            - content: Raw text content of the response
            - output: The structured response matching response_format
            - logprob: (optional) Sum of log probabilities if logprobs=True
            - prob: (optional) Exponential of logprob if logprobs=True

    Raises:
        ValueError: If the provider is unsupported, or if logprobs=True for a model that does not support it
    """
    model_name = AVAILABLE_MODELS[model]["model"]
    provider = model.split("/")[0]
    if provider == "Cohere":
        return _cohere_completion(model_name, system, prompt, response_format, temperature, logprobs)
    elif provider == "OpenAI":
        if _openai_is_json_mode_supported(model_name):
            return _openai_completion(model_name, system, prompt, response_format, temperature, logprobs)
        elif logprobs:
            raise ValueError(f"{model} does not support logprobs feature.")
        else:
            return _openai_langchain_completion(model_name, system, prompt, response_format, temperature)
    elif provider == "Anthropic":
        if logprobs:
            raise ValueError("Anthropic models do not support logprobs")
        return _anthropic_completion(model_name, system, prompt, response_format, temperature)
    else:
        raise ValueError(f"Provider {provider} not supported")
def completion(
    model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
) -> dict[str, Any]:
    """
    Generate a completion from an LLM provider with structured output, with caching.

    Args:
        model (str): Provider and model name in the format "provider/model" (e.g. "OpenAI/gpt-4")
        system (str): System prompt/instructions for the model
        prompt (str): User prompt/input
        response_format: Pydantic model defining the expected response structure
        temperature (float, optional): Sampling temperature passed through to the provider
        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
            Note: Not supported by Anthropic models; silently disabled with a warning.

    Returns:
        dict: Contains:
            - content: Raw text content of the response
            - output: The structured response matching response_format
            - logprob: (optional) Sum of log probabilities if logprobs=True
            - prob: (optional) Exponential of logprob if logprobs=True

    Raises:
        ValueError: If the requested model is not in AVAILABLE_MODELS
    """
    if model not in AVAILABLE_MODELS:
        raise ValueError(f"Model {model} not supported")
    if logprobs and not AVAILABLE_MODELS[model].get("logprobs", False):
        logger.warning(f"{model} does not support logprobs feature, setting logprobs to False")
        logprobs = False

    # Serve from cache unless logprobs were requested but are missing from the cached entry.
    cached_response = llm_cache.get(model, system, prompt, response_format, temperature)
    if cached_response and (not logprobs or cached_response.get("logprob")):
        logger.info(f"Cache hit for model {model}")
        return cached_response

    logger.info(f"Cache miss for model {model}, calling API. Logprobs: {logprobs}")

    response = _llm_completion(model, system, prompt, response_format, temperature, logprobs)

    llm_cache.set(
        model,
        system,
        prompt,
        response_format,
        temperature,
        response,
    )

    return response
if __name__ == "__main__":
    from tqdm import tqdm

    class ExplainedAnswer(BaseModel):
        """
        The answer to the question and a terse explanation of the answer.
        """

        answer: str = Field(description="The short answer to the question")
        explanation: str = Field(description="5 words terse best explanation of the answer.")

    models = list(AVAILABLE_MODELS.keys())[:1]
    system = "You are an accurate and concise explainer of scientific concepts."
    prompt = "Which planet is closest to the sun in the Milky Way galaxy? Answer directly, no explanation needed."

    llm_cache = LLMCache(cache_dir=".", hf_repo="umdclip/advcal-llm-cache", reset=True)

    logger.info("First call - should be a cache miss")
    for model in tqdm(models):
        response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
        rprint(response)

    logger.info("Second call - should be a cache hit")
    for model in tqdm(models):
        response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
        rprint(response)

    logger.info("Different prompt - should be a cache miss")
    prompt2 = "Which planet is closest to the sun? Answer directly."
    for model in tqdm(models):
        response = completion(model, system, prompt2, ExplainedAnswer, logprobs=False)
        rprint(response)

    try:
        cache_entries = llm_cache.get_all_entries()
        logger.info(f"Cache now has {len(cache_entries)} items")
    except Exception as e:
        logger.error(f"Failed to get cache entries: {e}")

    logger.info("Testing with temperature parameter")
    response = completion(models[0], system, "What is Mars?", ExplainedAnswer, temperature=0.7, logprobs=False)
    rprint(response)

    if llm_cache.hf_repo_id:
        logger.info("Forcing sync to HF dataset")
        try:
            llm_cache.sync_to_hf()
            logger.info("Successfully synced to HF dataset")
        except Exception as e:
            logger.exception(f"Failed to sync to HF: {e}")
    else:
        logger.info("HF repo not configured, skipping sync test")