# %%
import os
from typing import Optional

import cohere
import numpy as np
from langchain_anthropic import ChatAnthropic
from openai import OpenAI
from pydantic import BaseModel, Field
from rich import print as rprint

from app_configs import AVAILABLE_MODELS


class LLMOutput(BaseModel):
    content: str = Field(description="The content of the response")
    logprob: Optional[float] = Field(None, description="The log probability of the response")


def completion(model: str, system: str, prompt: str, response_format: type[BaseModel], logprobs: bool = False) -> dict:
    """
    Generate a completion from an LLM provider with structured output.

    Args:
        model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
        system (str): System prompt/instructions for the model
        prompt (str): User prompt/input
        response_format: Pydantic model defining the expected response structure
        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
            Note: Not supported by Anthropic models.

    Returns:
        dict: Contains:
            - output: The structured response matching response_format
            - logprob: (optional) Sum of log probabilities if logprobs=True
            - prob: (optional) Exponential of logprob if logprobs=True

    Raises:
        ValueError: If logprobs=True with Anthropic models
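
    Example (illustrative; AnswerModel stands for any Pydantic response model and
    the model key must exist in AVAILABLE_MODELS):
        >>> completion("OpenAI/gpt-4", "Be terse.", "What is 2+2?", AnswerModel, logprobs=True)
        {'content': '...', 'output': {'answer': '4'}, 'logprob': -0.02, 'prob': 0.98}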
    """
    if model not in AVAILABLE_MODELS:
        raise ValueError(f"Model {model} not supported")
    model_name = AVAILABLE_MODELS[model]["model"]
    provider = model.split("/")[0]
    if provider == "Cohere":
        return _cohere_completion(model_name, system, prompt, response_format, logprobs)
    elif provider == "OpenAI":
        return _openai_completion(model_name, system, prompt, response_format, logprobs)
    elif provider == "Anthropic":
        if logprobs:
            raise ValueError("Anthropic does not support logprobs")
        return _anthropic_completion(model_name, system, prompt, response_format)
    else:
        raise ValueError(f"Provider {provider} not supported")


def _cohere_completion(model: str, system: str, prompt: str, response_model: type[BaseModel], logprobs: bool = True) -> dict:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    client = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
    response = client.chat(
        model=model,
        messages=messages,
        response_format={"type": "json_schema", "json_schema": response_model.model_json_schema()},
        logprobs=logprobs,
    )
    output = {}
    content = response.message.content[0].text
    output["content"] = content
    output["output"] = response_model.model_validate_json(content).model_dump()
    if logprobs:
        # Sum per-token log probabilities, then exponentiate for the overall sequence probability.
        output["logprob"] = sum(lp.logprobs[0] for lp in response.logprobs)
        output["prob"] = np.exp(output["logprob"])
    return output


def _openai_completion(model: str, system: str, prompt: str, response_model: type[BaseModel], logprobs: bool = True) -> dict:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    response = client.beta.chat.completions.parse(
        model=model,
        messages=messages,
        response_format=response_model,
        logprobs=logprobs,
    )
    output = {}
    output["content"] = response.choices[0].message.content
    output["output"] = response.choices[0].message.parsed.model_dump()
    if logprobs:
        output["logprob"] = sum(lp.logprob for lp in response.choices[0].logprobs.content)
        output["prob"] = np.exp(output["logprob"])
    return output


def _anthropic_completion(model: str, system: str, prompt: str, response_model: type[BaseModel]) -> dict:
    llm = ChatAnthropic(model=model).with_structured_output(response_model, include_raw=True)
    # With include_raw=True, invoke() returns a dict with "raw", "parsed", and "parsing_error"
    # keys, so the result must be indexed, not accessed as attributes.
    result = llm.invoke([("system", system), ("human", prompt)])
    return {"content": result["raw"].content, "output": result["parsed"].model_dump()}


if __name__ == "__main__":

    class ExplainedAnswer(BaseModel):
        """
        The answer to the question and a terse explanation of the answer.
        """

        answer: str = Field(description="The short answer to the question")
        explanation: str = Field(description="Terse explanation of the answer, at most five words.")

    model = "Anthropic/claude-3-5-sonnet-20240620"
    system = "You are an accurate and concise explainer of scientific concepts."
    prompt = "Which planet is closest to the sun in the Milky Way galaxy? Answer directly, no explanation needed."

    # response = _cohere_completion("command-r", system, prompt, ExplainedAnswer, logprobs=True)
    response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
    rprint(response)
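
    # The returned dict has this illustrative shape (values are made up):
    # {
    #     "content": '{"answer": "Mercury", "explanation": "Smallest orbital radius"}',
    #     "output": {"answer": "Mercury", "explanation": "Smallest orbital radius"},
    # }
    # With logprobs=True (supported by the OpenAI and Cohere paths), "logprob" and "prob"
    # are included as well.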

# %%