# %%
import json
import os
from typing import Optional

import cohere
import numpy as np
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_openai import ChatOpenAI
from loguru import logger
from openai import OpenAI
from pydantic import BaseModel, Field
from rich import print as rprint

from app_configs import AVAILABLE_MODELS


def _openai_is_json_mode_supported(model_name: str) -> bool:
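    # Heuristic: only gpt-4-family models are routed through OpenAI's native structured-output
    # parsing here; gpt-3.5 and unrecognized models fall back to the LangChain path in completion().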
    if model_name.startswith("gpt-4"):
        return True
    if model_name.startswith("gpt-3.5"):
        return False
    logger.warning(f"OpenAI model {model_name} is not available in this app, skipping JSON mode, returning False")
    return False


class LLMOutput(BaseModel):
    content: str = Field(description="The content of the response")
    logprob: Optional[float] = Field(None, description="The log probability of the response")


def _get_langchain_chat_output(llm, system: str, prompt: str) -> dict:
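    # Invoke a LangChain chat model built with .with_structured_output(..., include_raw=True):
    # "raw" is the underlying AIMessage and "parsed" is the validated Pydantic instance.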
    output = llm.invoke([("system", system), ("human", prompt)])
    ai_message = output["raw"]
    content = {"content": ai_message.content, "tool_calls": ai_message.tool_calls}
    content_str = json.dumps(content)
    return {"content": content_str, "output": output["parsed"].model_dump()}


def _cohere_completion(model: str, system: str, prompt: str, response_model, logprobs: bool = True) -> dict:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    client = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
    response = client.chat(
        model=model,
        messages=messages,
        response_format={"type": "json_schema", "json_schema": response_model.model_json_schema()},
        logprobs=logprobs,
    )
    output = {}
    output["content"] = response.message.content[0].text
    output["output"] = response_model.model_validate_json(response.message.content[0].text).model_dump()
    if logprobs:
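        # Sum per-token log-probabilities into a joint sequence log-probability; exponentiate for an overall probability.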
        output["logprob"] = sum(lp.logprobs[0] for lp in response.logprobs)
        output["prob"] = np.exp(output["logprob"])
    return output


def _openai_langchain_completion(model: str, system: str, prompt: str, response_model, logprobs: bool = True) -> dict:
    llm = ChatOpenAI(model=model).with_structured_output(response_model, include_raw=True)
    return _get_langchain_chat_output(llm, system, prompt)


def _openai_completion(model: str, system: str, prompt: str, response_model, logprobs: bool = True) -> dict:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    response = client.beta.chat.completions.parse(
        model=model,
        messages=messages,
        response_format=response_model,
        logprobs=logprobs,
    )
    output = {}
    output["content"] = response.choices[0].message.content
    output["output"] = response.choices[0].message.parsed.model_dump()
    if logprobs:
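        # Sum per-token logprobs into a single sequence log-probability, then exponentiate.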
        output["logprob"] = sum(lp.logprob for lp in response.choices[0].logprobs.content)
        output["prob"] = np.exp(output["logprob"])
    return output


def _anthropic_completion(model: str, system: str, prompt: str, response_model) -> dict:
    llm = ChatAnthropic(model=model).with_structured_output(response_model, include_raw=True)
    return _get_langchain_chat_output(llm, system, prompt)


def completion(model: str, system: str, prompt: str, response_format, logprobs: bool = False) -> dict:
    """
    Generate a completion from an LLM provider with structured output.

    Args:
        model (str): Provider and model name in the format "provider/model" (e.g. "OpenAI/gpt-4")
        system (str): System prompt/instructions for the model
        prompt (str): User prompt/input
        response_format: Pydantic model defining the expected response structure
        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
            Note: Not supported by Anthropic models.

    Returns:
        dict: Contains:
            - content: The raw model response serialized as a JSON string
            - output: The structured response matching response_format
            - logprob: (optional) Sum of log probabilities if logprobs=True
            - prob: (optional) Exponential of logprob if logprobs=True

    Raises:
        ValueError: If logprobs=True with Anthropic models
    """
    if model not in AVAILABLE_MODELS:
        raise ValueError(f"Model {model} not supported")
    model_name = AVAILABLE_MODELS[model]["model"]
    provider = model.split("/")[0]
    if provider == "Cohere":
        return _cohere_completion(model_name, system, prompt, response_format, logprobs)
    elif provider == "OpenAI":
        if _openai_is_json_mode_supported(model_name):
            return _openai_completion(model_name, system, prompt, response_format, logprobs)
        else:
            return _openai_langchain_completion(model_name, system, prompt, response_format, logprobs)
    elif provider == "Anthropic":
        if logprobs:
            raise ValueError("Anthropic does not support logprobs")
        return _anthropic_completion(model_name, system, prompt, response_format)
    else:
        raise ValueError(f"Provider {provider} not supported")


# %%
if __name__ == "__main__":
    from tqdm import tqdm

    class ExplainedAnswer(BaseModel):
        """
        The answer to the question and a terse explanation of the answer.
        """

        answer: str = Field(description="The short answer to the question")
        explanation: str = Field(description="A terse explanation of the answer in at most 5 words.")

    models = AVAILABLE_MODELS.keys()
    system = "You are an accurate and concise explainer of scientific concepts."
    prompt = "Which planet is closest to the sun in the Milky Way galaxy? Answer directly, no explanation needed."

    for model in tqdm(models):
        response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
        rprint(response)
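
    # A minimal sketch of requesting log-probabilities: only the Cohere and native OpenAI paths
    # return them, and the Anthropic path raises ValueError. "Cohere/command-r" is a hypothetical
    # key; substitute any entry from AVAILABLE_MODELS before running.
    # response = completion("Cohere/command-r", system, prompt, ExplainedAnswer, logprobs=True)
    # rprint(response)  # would include "logprob" and "prob" alongside "content" and "output"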

# %%