Maharshi Gor committed on
Commit 2900a81 · 1 Parent(s): d15e788

Bugfix: llm model calls for Anthropic and gpt-3.5

Files changed (1)
  1. src/llms.py +68 -44
src/llms.py CHANGED
@@ -4,60 +4,38 @@ import os
 from typing import Optional
 
 import cohere
-import json_repair
 import numpy as np
-from anthropic import Anthropic
 from langchain_anthropic import ChatAnthropic
 from langchain_cohere import ChatCohere
 from langchain_openai import ChatOpenAI
+from loguru import logger
 from openai import OpenAI
 from pydantic import BaseModel, Field
 from rich import print as rprint
 
-import utils
 from app_configs import AVAILABLE_MODELS
 
 
+def _openai_is_json_mode_supported(model_name: str) -> bool:
+    if model_name.startswith("gpt-4"):
+        return True
+    if model_name.startswith("gpt-3.5"):
+        return False
+    logger.warning(f"OpenAI model {model_name} is not available in this app, skipping JSON mode, returning False")
+    return False
+
+
 class LLMOutput(BaseModel):
     content: str = Field(description="The content of the response")
     logprob: Optional[float] = Field(None, description="The log probability of the response")
 
 
-def completion(model: str, system: str, prompt: str, response_format, logprobs: bool = False) -> str:
-    """
-    Generate a completion from an LLM provider with structured output.
-
-    Args:
-        model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
-        system (str): System prompt/instructions for the model
-        prompt (str): User prompt/input
-        response_format: Pydantic model defining the expected response structure
-        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
-            Note: Not supported by Anthropic models.
-
-    Returns:
-        dict: Contains:
-            - output: The structured response matching response_format
-            - logprob: (optional) Sum of log probabilities if logprobs=True
-            - prob: (optional) Exponential of logprob if logprobs=True
-
-    Raises:
-        ValueError: If logprobs=True with Anthropic models
-    """
-    if model not in AVAILABLE_MODELS:
-        raise ValueError(f"Model {model} not supported")
-    model_name = AVAILABLE_MODELS[model]["model"]
-    provider = model.split("/")[0]
-    if provider == "Cohere":
-        return _cohere_completion(model_name, system, prompt, response_format, logprobs)
-    elif provider == "OpenAI":
-        return _openai_completion(model_name, system, prompt, response_format, logprobs)
-    elif provider == "Anthropic":
-        if logprobs:
-            raise ValueError("Anthropic does not support logprobs")
-        return _anthropic_completion(model_name, system, prompt, response_format)
-    else:
-        raise ValueError(f"Provider {provider} not supported")
+def _get_langchain_chat_output(llm, system: str, prompt: str) -> str:
+    output = llm.invoke([("system", system), ("human", prompt)])
+    ai_message = output["raw"]
+    content = {"content": ai_message.content, "tool_calls": ai_message.tool_calls}
+    content_str = json.dumps(content)
+    return {"content": content_str, "output": output["parsed"].model_dump()}
 
 
 def _cohere_completion(model: str, system: str, prompt: str, response_model, logprobs: bool = True) -> str:
@@ -81,6 +59,11 @@ def _cohere_completion(model: str, system: str, prompt: str, response_model, logprobs: bool = True) -> str:
     return output
 
 
+def _openai_langchain_completion(model: str, system: str, prompt: str, response_model, logprobs: bool = True) -> str:
+    llm = ChatOpenAI(model=model).with_structured_output(response_model, include_raw=True)
+    return _get_langchain_chat_output(llm, system, prompt)
+
+
 def _openai_completion(model: str, system: str, prompt: str, response_model, logprobs: bool = True) -> str:
     messages = [
         {"role": "system", "content": system},
@@ -104,11 +87,52 @@ def _openai_completion(model: str, system: str, prompt: str, response_model, logprobs: bool = True) -> str:
 
 def _anthropic_completion(model: str, system: str, prompt: str, response_model) -> str:
     llm = ChatAnthropic(model=model).with_structured_output(response_model, include_raw=True)
-    output = llm.invoke([("system", system), ("human", prompt)])
-    return {"content": output.raw, "output": output.parsed.model_dump()}
+    return _get_langchain_chat_output(llm, system, prompt)
+
+
+def completion(model: str, system: str, prompt: str, response_format, logprobs: bool = False) -> str:
+    """
+    Generate a completion from an LLM provider with structured output.
+
+    Args:
+        model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
+        system (str): System prompt/instructions for the model
+        prompt (str): User prompt/input
+        response_format: Pydantic model defining the expected response structure
+        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
+            Note: Not supported by Anthropic models.
+
+    Returns:
+        dict: Contains:
+            - output: The structured response matching response_format
+            - logprob: (optional) Sum of log probabilities if logprobs=True
+            - prob: (optional) Exponential of logprob if logprobs=True
+
+    Raises:
+        ValueError: If logprobs=True with Anthropic models
+    """
+    if model not in AVAILABLE_MODELS:
+        raise ValueError(f"Model {model} not supported")
+    model_name = AVAILABLE_MODELS[model]["model"]
+    provider = model.split("/")[0]
+    if provider == "Cohere":
+        return _cohere_completion(model_name, system, prompt, response_format, logprobs)
+    elif provider == "OpenAI":
+        if _openai_is_json_mode_supported(model_name):
+            return _openai_completion(model_name, system, prompt, response_format, logprobs)
+        else:
+            return _openai_langchain_completion(model_name, system, prompt, response_format, logprobs)
+    elif provider == "Anthropic":
+        if logprobs:
+            raise ValueError("Anthropic does not support logprobs")
+        return _anthropic_completion(model_name, system, prompt, response_format)
+    else:
+        raise ValueError(f"Provider {provider} not supported")
 
 
+# %%
 if __name__ == "__main__":
+    from tqdm import tqdm
 
     class ExplainedAnswer(BaseModel):
         """
@@ -118,12 +142,12 @@ if __name__ == "__main__":
         answer: str = Field(description="The short answer to the question")
         explanation: str = Field(description="5 words terse best explanation of the answer.")
 
-    model = "Anthropic/claude-3-5-sonnet-20240620"
+    models = AVAILABLE_MODELS.keys()
     system = "You are an accurate and concise explainer of scientific concepts."
    prompt = "Which planet is closest to the sun in the Milky Way galaxy? Answer directly, no explanation needed."
 
-    # response = _cohere_completion("command-r", system, prompt, ExplainedAnswer, logprobs=True)
-    response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
-    rprint(response)
+    for model in tqdm(models):
+        response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
+        rprint(response)
 
 # %%
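
Note on the Anthropic fix: with include_raw=True, LangChain's with_structured_output returns a plain dict with "raw", "parsed", and "parsing_error" keys rather than an object with attributes, so the old output.raw / output.parsed access raised AttributeError at call time. A minimal sketch of the corrected access pattern (the model name and prompts here are illustrative, not taken from the commit):

import json

from langchain_anthropic import ChatAnthropic
from pydantic import BaseModel, Field


class Answer(BaseModel):
    answer: str = Field(description="The short answer")


# include_raw=True => the chain returns a dict:
# {"raw": AIMessage, "parsed": Answer | None, "parsing_error": Exception | None}
llm = ChatAnthropic(model="claude-3-5-sonnet-20240620").with_structured_output(Answer, include_raw=True)
output = llm.invoke([("system", "Be terse."), ("human", "Which planet is closest to the sun?")])

ai_message = output["raw"]  # dict indexing works; attribute access (output.raw) did not
print(json.dumps({"content": ai_message.content, "tool_calls": ai_message.tool_calls}, default=str))
print(output["parsed"].model_dump())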
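
Usage sketch for the new routing: gpt-4-family models keep the native JSON-mode path (_openai_completion), while gpt-3.5 models fall back to the LangChain tool-calling path (_openai_langchain_completion), and Anthropic models share the same helper. The model key below is an assumption about what AVAILABLE_MODELS in app_configs contains; the real keys may differ:

from pydantic import BaseModel, Field

from llms import completion  # assumes src/ is on the import path


class ExplainedAnswer(BaseModel):
    answer: str = Field(description="The short answer to the question")
    explanation: str = Field(description="5 words terse best explanation of the answer.")


system = "You are an accurate and concise explainer of scientific concepts."
prompt = "Which planet is closest to the sun in the Milky Way galaxy?"

# "OpenAI/gpt-3.5..." -> _openai_is_json_mode_supported() is False -> LangChain fallback
# "OpenAI/gpt-4..."   -> native JSON mode via _openai_completion
# "Anthropic/..."     -> _anthropic_completion; logprobs=True would raise ValueError
response = completion("OpenAI/gpt-3.5-turbo", system, prompt, ExplainedAnswer, logprobs=False)
print(response["output"])  # e.g. {"answer": "Mercury", "explanation": "..."}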