# %%

import json
import os
from typing import Any, Optional

import cohere
import numpy as np
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_core.language_models import BaseChatModel
from langchain_openai import ChatOpenAI
from loguru import logger
from openai import OpenAI
from pydantic import BaseModel, Field
from rich import print as rprint

from .configs import AVAILABLE_MODELS
from .llmcache import LLMCache

# Initialize global cache
llm_cache = LLMCache(cache_dir=".", hf_repo="umdclip/advcal-llm-cache")


def _openai_is_json_mode_supported(model_name: str) -> bool:
    if model_name.startswith("gpt-4"):
        return True
    if model_name.startswith("gpt-3.5"):
        return False
    logger.warning(f"OpenAI model {model_name} is not available in this app, skipping JSON mode, returning False")
    return False


class LLMOutput(BaseModel):
    content: str = Field(description="The content of the response")
    logprob: Optional[float] = Field(None, description="The log probability of the response")


def _get_langchain_chat_output(llm: BaseChatModel, system: str, prompt: str) -> dict[str, Any]:
    output = llm.invoke([("system", system), ("human", prompt)])
    ai_message = output["raw"]
    content = {"content": ai_message.content, "tool_calls": ai_message.tool_calls}
    content_str = json.dumps(content)
    return {"content": content_str, "output": output["parsed"].model_dump()}


def _cohere_completion(
    model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
) -> dict[str, Any]:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    client = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
    response = client.chat(
        model=model,
        messages=messages,
        response_format={"type": "json_schema", "json_schema": response_model.model_json_schema()},
        logprobs=logprobs,
        temperature=temperature,
    )
    output = {}
    output["content"] = response.message.content[0].text
    output["output"] = response_model.model_validate_json(response.message.content[0].text).model_dump()
    if logprobs:
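        # Sequence log-probability: sum the per-token logprobs reported by the Cohere API;
        # the exponential below converts it into an overall probability for the generation.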
        output["logprob"] = sum(lp.logprobs[0] for lp in response.logprobs)
        output["prob"] = np.exp(output["logprob"])
    return output


def _openai_langchain_completion(
    model: str, system: str, prompt: str, response_model, temperature: float | None = None
) -> dict[str, Any]:
    llm = ChatOpenAI(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
    return _get_langchain_chat_output(llm, system, prompt)


def _openai_completion(
    model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
) -> dict[str, Any]:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    response = client.beta.chat.completions.parse(
        model=model,
        messages=messages,
        response_format=response_model,
        logprobs=logprobs,
        temperature=temperature,
    )
    output = {}
    output["content"] = response.choices[0].message.content
    output["output"] = response.choices[0].message.parsed.model_dump()
    if logprobs:
        output["logprob"] = sum(lp.logprob for lp in response.choices[0].logprobs.content)
        output["prob"] = np.exp(output["logprob"])
    return output


def _anthropic_completion(
    model: str, system: str, prompt: str, response_model, temperature: float | None = None
) -> dict[str, Any]:
    llm = ChatAnthropic(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
    return _get_langchain_chat_output(llm, system, prompt)


def _llm_completion(
    model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
) -> dict[str, Any]:
    """
    Generate a completion from an LLM provider with structured output without caching.

    Args:
        model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
        system (str): System prompt/instructions for the model
        prompt (str): User prompt/input
        response_format: Pydantic model defining the expected response structure
        temperature (float, optional): Sampling temperature forwarded to the provider. Defaults to None.
        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
            Note: Not supported by Anthropic models.

    Returns:
        dict: Contains:
            - content: The raw response content returned by the provider
            - output: The structured response matching response_format
            - logprob: (optional) Sum of log probabilities if logprobs=True
            - prob: (optional) Exponential of logprob if logprobs=True

    Raises:
        ValueError: If logprobs=True with Anthropic models
    """
    model_name = AVAILABLE_MODELS[model]["model"]
    provider = model.split("/")[0]
    if provider == "Cohere":
        return _cohere_completion(model_name, system, prompt, response_format, temperature, logprobs)
    elif provider == "OpenAI":
        if _openai_is_json_mode_supported(model_name):
            return _openai_completion(model_name, system, prompt, response_format, temperature, logprobs)
        elif logprobs:
            raise ValueError(f"{model} does not support logprobs feature.")
        else:
            return _openai_langchain_completion(model_name, system, prompt, response_format, temperature)
    elif provider == "Anthropic":
        if logprobs:
            raise ValueError("Anthropic models do not support logprobs")
        return _anthropic_completion(model_name, system, prompt, response_format, temperature)
    else:
        raise ValueError(f"Provider {provider} not supported")


def completion(
    model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
) -> dict[str, Any]:
    """
    Generate a completion from an LLM provider with structured output with caching.

    Args:
        model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
        system (str): System prompt/instructions for the model
        prompt (str): User prompt/input
        response_format: Pydantic model defining the expected response structure
        temperature (float, optional): Sampling temperature forwarded to the provider. Defaults to None.
        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
            Note: Not supported by Anthropic models.

    Returns:
        dict: Contains:
            - content: The raw response content returned by the provider
            - output: The structured response matching response_format
            - logprob: (optional) Sum of log probabilities if logprobs=True
            - prob: (optional) Exponential of logprob if logprobs=True

    Raises:
        ValueError: If logprobs=True with Anthropic models
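
    Example:
        Illustrative sketch only; "Cohere/command-r" is a placeholder key that must exist
        in AVAILABLE_MODELS, and the matching provider API key must be set in the environment:

            class Answer(BaseModel):
                text: str = Field(description="Short answer")

            result = completion(
                "Cohere/command-r",
                system="You are terse.",
                prompt="Name the planet closest to the sun.",
                response_format=Answer,
                logprobs=False,
            )
            result["output"]  # e.g. {"text": "Mercury"}; repeated identical calls hit the cache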
    """
    if model not in AVAILABLE_MODELS:
        raise ValueError(f"Model {model} not supported")
    if logprobs and not AVAILABLE_MODELS[model].get("logprobs", False):
        logger.warning(f"{model} does not support logprobs feature, setting logprobs to False")
        logprobs = False

    # Check cache first
    cached_response = llm_cache.get(model, system, prompt, response_format, temperature)
    if cached_response and (not logprobs or cached_response.get("logprob") is not None):
        logger.info(f"Cache hit for model {model}")
        return cached_response

    logger.info(f"Cache miss for model {model}, calling API. Logprobs: {logprobs}")

    # Continue with the original implementation for cache miss
    response = _llm_completion(model, system, prompt, response_format, temperature, logprobs)

    # Update cache with the new response
    llm_cache.set(
        model,
        system,
        prompt,
        response_format,
        temperature,
        response,
    )

    return response


# %%
if __name__ == "__main__":
    from tqdm import tqdm

    class ExplainedAnswer(BaseModel):
        """
        The answer to the question and a terse explanation of the answer.
        """

        answer: str = Field(description="The short answer to the question")
        explanation: str = Field(description="A terse explanation of the answer in five words or fewer.")

    models = list(AVAILABLE_MODELS.keys())[:1]  # Just use the first model for testing
    system = "You are an accurate and concise explainer of scientific concepts."
    prompt = "Which planet is closest to the sun in the Milky Way galaxy? Answer directly, no explanation needed."

    llm_cache = LLMCache(cache_dir=".", hf_repo="umdclip/advcal-llm-cache", reset=True)

    # First call - should be a cache miss
    logger.info("First call - should be a cache miss")
    for model in tqdm(models):
        response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
        rprint(response)

    # Second call - should be a cache hit
    logger.info("Second call - should be a cache hit")
    for model in tqdm(models):
        response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
        rprint(response)

    # Slightly different prompt - should be a cache miss
    logger.info("Different prompt - should be a cache miss")
    prompt2 = "Which planet is closest to the sun? Answer directly."
    for model in tqdm(models):
        response = completion(model, system, prompt2, ExplainedAnswer, logprobs=False)
        rprint(response)

    # Get cache entries count from SQLite
    try:
        cache_entries = llm_cache.get_all_entries()
        logger.info(f"Cache now has {len(cache_entries)} items")
    except Exception as e:
        logger.error(f"Failed to get cache entries: {e}")

    # Test adding entry with temperature parameter
    logger.info("Testing with temperature parameter")
    response = completion(models[0], system, "What is Mars?", ExplainedAnswer, temperature=0.7, logprobs=False)
    rprint(response)

    # Demonstrate forced sync to HF if repo is configured
    if llm_cache.hf_repo_id:
        logger.info("Forcing sync to HF dataset")
        try:
            llm_cache.sync_to_hf()
            logger.info("Successfully synced to HF dataset")
        except Exception as e:
            logger.exception(f"Failed to sync to HF: {e}")
    else:
        logger.info("HF repo not configured, skipping sync test")

# %%