# %%
import os
from typing import Optional

import cohere
import numpy as np
from langchain_anthropic import ChatAnthropic
from openai import OpenAI
from pydantic import BaseModel, Field
from rich import print as rprint

from app_configs import AVAILABLE_MODELS


class LLMOutput(BaseModel):
    content: str = Field(description="The content of the response")
    logprob: Optional[float] = Field(None, description="The log probability of the response")


def completion(model: str, system: str, prompt: str, response_format, logprobs: bool = False) -> dict:
"""
Generate a completion from an LLM provider with structured output.
Args:
model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
system (str): System prompt/instructions for the model
prompt (str): User prompt/input
response_format: Pydantic model defining the expected response structure
logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
Note: Not supported by Anthropic models.
Returns:
dict: Contains:
- output: The structured response matching response_format
- logprob: (optional) Sum of log probabilities if logprobs=True
- prob: (optional) Exponential of logprob if logprobs=True
Raises:
ValueError: If logprobs=True with Anthropic models
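
    Example (a hedged sketch; "OpenAI/gpt-4" and AnswerSchema are assumptions --
    the key must exist in AVAILABLE_MODELS and the schema must be a Pydantic model):
        >>> result = completion("OpenAI/gpt-4", "Be terse.", "What is 2 + 2?", AnswerSchema)
        >>> result["output"]  # dict matching AnswerSchema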
"""
    if model not in AVAILABLE_MODELS:
        raise ValueError(f"Model {model} not supported")
    model_name = AVAILABLE_MODELS[model]["model"]
    provider = model.split("/")[0]
    if provider == "Cohere":
        return _cohere_completion(model_name, system, prompt, response_format, logprobs)
    elif provider == "OpenAI":
        return _openai_completion(model_name, system, prompt, response_format, logprobs)
    elif provider == "Anthropic":
        if logprobs:
            raise ValueError("Anthropic does not support logprobs")
        return _anthropic_completion(model_name, system, prompt, response_format)
    else:
        raise ValueError(f"Provider {provider} not supported")


def _cohere_completion(model: str, system: str, prompt: str, response_model, logprobs: bool = True) -> dict:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    client = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
    response = client.chat(
        model=model,
        messages=messages,
        response_format={"type": "json_schema", "json_schema": response_model.model_json_schema()},
        logprobs=logprobs,
    )
    output = {}
    output["content"] = response.message.content[0].text
    output["output"] = response_model.model_validate_json(response.message.content[0].text).model_dump()
    if logprobs:
        # Sum of per-token log probabilities; exponentiating yields the joint
        # probability of the full generated sequence.
        output["logprob"] = sum(lp.logprobs[0] for lp in response.logprobs)
        output["prob"] = np.exp(output["logprob"])
    return output


def _openai_completion(model: str, system: str, prompt: str, response_model, logprobs: bool = True) -> dict:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    response = client.beta.chat.completions.parse(
        model=model,
        messages=messages,
        response_format=response_model,
        logprobs=logprobs,
    )
    output = {}
    output["content"] = response.choices[0].message.content
    output["output"] = response.choices[0].message.parsed.model_dump()
    if logprobs:
        output["logprob"] = sum(lp.logprob for lp in response.choices[0].logprobs.content)
        output["prob"] = np.exp(output["logprob"])
    return output


def _anthropic_completion(model: str, system: str, prompt: str, response_model) -> dict:
    llm = ChatAnthropic(model=model).with_structured_output(response_model, include_raw=True)
    output = llm.invoke([("system", system), ("human", prompt)])
    # with include_raw=True, LangChain returns a dict with "raw" (the AIMessage)
    # and "parsed" (the validated Pydantic object), not an object with attributes.
    return {"content": output["raw"].content, "output": output["parsed"].model_dump()}


if __name__ == "__main__":

    class ExplainedAnswer(BaseModel):
        """
        The answer to the question and a terse explanation of the answer.
        """

        answer: str = Field(description="The short answer to the question")
        explanation: str = Field(description="A terse, five-word explanation of the answer.")

    model = "Anthropic/claude-3-5-sonnet-20240620"
    system = "You are an accurate and concise explainer of scientific concepts."
    prompt = "Which planet is closest to the sun in the Milky Way galaxy? Answer directly, no explanation needed."

    # response = _cohere_completion("command-r", system, prompt, ExplainedAnswer, logprobs=True)
    response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
    rprint(response)
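
    # A hedged sketch: the same call with logprobs=True, which this module
    # supports for the Cohere and OpenAI providers only (Anthropic raises).
    # The "OpenAI/gpt-4" key below is an assumption and must exist in AVAILABLE_MODELS.
    # response = completion("OpenAI/gpt-4", system, prompt, ExplainedAnswer, logprobs=True)
    # rprint(response["prob"])  # exp(sum of token logprobs), in (0, 1]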
# %%