# %%
import os
from typing import Optional

import cohere
import numpy as np
from langchain_anthropic import ChatAnthropic
from openai import OpenAI
from pydantic import BaseModel, Field
from rich import print as rprint

from envs import AVAILABLE_MODELS


class LLMOutput(BaseModel):
    content: str = Field(description="The content of the response")
    logprob: Optional[float] = Field(None, description="The log probability of the response")


def completion(model: str, system: str, prompt: str, response_format, logprobs: bool = False) -> dict:
    """
    Generate a completion from an LLM provider with structured output.

    Args:
        model (str): Provider and model name in the format "provider/model" (e.g. "OpenAI/gpt-4")
        system (str): System prompt/instructions for the model
        prompt (str): User prompt/input
        response_format: Pydantic model defining the expected response structure
        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
            Note: not supported by Anthropic models.

    Returns:
        dict: Contains:
            - content: Raw text of the model response
            - output: The structured response matching response_format, as a dict
            - logprob: (optional) Sum of token log probabilities if logprobs=True
            - prob: (optional) exp(logprob) if logprobs=True

    Raises:
        ValueError: If the model is unknown, the provider is unsupported, or
            logprobs=True is requested for an Anthropic model.
    """
    if model not in AVAILABLE_MODELS:
        raise ValueError(f"Model {model} not supported")
    model_name = AVAILABLE_MODELS[model]["model"]
    provider = model.split("/")[0]
    if provider == "Cohere":
        return _cohere_completion(model_name, system, prompt, response_format, logprobs)
    elif provider == "OpenAI":
        return _openai_completion(model_name, system, prompt, response_format, logprobs)
    elif provider == "Anthropic":
        if logprobs:
            raise ValueError("Anthropic does not support logprobs")
        return _anthropic_completion(model_name, system, prompt, response_format)
    else:
        raise ValueError(f"Provider {provider} not supported")


def _cohere_completion(model: str, system: str, prompt: str, response_model, logprobs: bool = True) -> dict:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    client = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
    response = client.chat(
        model=model,
        messages=messages,
        response_format={"type": "json_schema", "json_schema": response_model.model_json_schema()},
        logprobs=logprobs,
    )
    output = {}
    output["content"] = response.message.content[0].text
    output["output"] = response_model.model_validate_json(response.message.content[0].text).model_dump()
    if logprobs:
        # Summing the per-token logprobs gives the log-probability of the
        # whole completion.
        output["logprob"] = sum(lp.logprobs[0] for lp in response.logprobs)
        output["prob"] = np.exp(output["logprob"])
    return output


def _openai_completion(model: str, system: str, prompt: str, response_model, logprobs: bool = True) -> dict:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    response = client.beta.chat.completions.parse(
        model=model,
        messages=messages,
        response_format=response_model,
        logprobs=logprobs,
    )
    output = {}
    output["content"] = response.choices[0].message.content
    output["output"] = response.choices[0].message.parsed.model_dump()
    if logprobs:
        output["logprob"] = sum(lp.logprob for lp in response.choices[0].logprobs.content)
        output["prob"] = np.exp(output["logprob"])
    return output
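# The two helpers above both reduce token-level logprobs to a single sequence
# probability. A self-contained sketch of that arithmetic, using made-up
# per-token values rather than real API output:


def _sequence_prob_example() -> float:
    """Show how summed token logprobs become a sequence probability."""
    token_logprobs = [-0.105, -0.223, -0.051]  # hypothetical per-token logprobs
    logprob = sum(token_logprobs)  # log P(t1, t2, t3) = -0.379
    return float(np.exp(logprob))  # ~0.684, i.e. P(t1) * P(t2) * P(t3)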
def _anthropic_completion(model: str, system: str, prompt: str, response_model) -> dict:
    # with_structured_output(..., include_raw=True) returns a dict with "raw"
    # (the underlying AIMessage) and "parsed" (the validated Pydantic object).
    llm = ChatAnthropic(model=model).with_structured_output(response_model, include_raw=True)
    output = llm.invoke([("system", system), ("human", prompt)])
    return {"content": output["raw"].content, "output": output["parsed"].model_dump()}


if __name__ == "__main__":

    class ExplainedAnswer(BaseModel):
        """The answer to the question and a terse explanation of the answer."""

        answer: str = Field(description="The short answer to the question")
        explanation: str = Field(description="Terse, best explanation of the answer in 5 words")

    model = "Anthropic/claude-3-5-sonnet-20240620"
    system = "You are an accurate and concise explainer of scientific concepts."
    prompt = "Which planet is closest to the sun in the Milky Way galaxy? Answer directly, no explanation needed."

    # response = _cohere_completion("command-r", system, prompt, ExplainedAnswer, logprobs=True)
    response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
    rprint(response)

# %%
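# A second smoke test exercising the logprobs path, left commented out because
# it needs OPENAI_API_KEY and an OpenAI entry in AVAILABLE_MODELS; the key
# "OpenAI/gpt-4o-mini" is an assumed registry name, not taken from envs.
#
# response = completion("OpenAI/gpt-4o-mini", system, prompt, ExplainedAnswer, logprobs=True)
# rprint(response["prob"])  # exp of the summed token logprobs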