"""Groq Chat wrapper.""" | |
from __future__ import annotations | |
import json | |
import os | |
import warnings | |
from operator import itemgetter | |
from typing import ( | |
Any, | |
AsyncIterator, | |
Callable, | |
Dict, | |
Iterator, | |
List, | |
Literal, | |
Mapping, | |
Optional, | |
Sequence, | |
Tuple, | |
Type, | |
TypedDict, | |
Union, | |
cast, | |
) | |
from langchain_core.callbacks import ( | |
AsyncCallbackManagerForLLMRun, | |
CallbackManagerForLLMRun, | |
) | |
from langchain_core.language_models import LanguageModelInput | |
from langchain_core.language_models.chat_models import ( | |
BaseChatModel, | |
LangSmithParams, | |
agenerate_from_stream, | |
generate_from_stream, | |
) | |
from langchain_core.messages import ( | |
AIMessage, | |
AIMessageChunk, | |
BaseMessage, | |
BaseMessageChunk, | |
ChatMessage, | |
ChatMessageChunk, | |
FunctionMessage, | |
FunctionMessageChunk, | |
HumanMessage, | |
HumanMessageChunk, | |
InvalidToolCall, | |
SystemMessage, | |
SystemMessageChunk, | |
ToolCall, | |
ToolMessage, | |
ToolMessageChunk, | |
) | |
from langchain_core.output_parsers import ( | |
JsonOutputParser, | |
PydanticOutputParser, | |
) | |
from langchain_core.output_parsers.base import OutputParserLike | |
from langchain_core.output_parsers.openai_tools import ( | |
JsonOutputKeyToolsParser, | |
PydanticToolsParser, | |
make_invalid_tool_call, | |
parse_tool_call, | |
) | |
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult | |
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator | |
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough | |
from langchain_core.tools import BaseTool | |
from langchain_core.utils import ( | |
convert_to_secret_str, | |
get_from_dict_or_env, | |
get_pydantic_field_names, | |
) | |
from langchain_core.utils.function_calling import ( | |
convert_to_openai_function, | |
convert_to_openai_tool, | |
) | |

class ChatGroq(BaseChatModel):
    """`Groq` Chat large language models API.

    To use, you should have the
    environment variable ``GROQ_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the groq.create call can be
    passed in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain_groq import ChatGroq

            model = ChatGroq(model_name="mixtral-8x7b-32768")
    """
    client: Any = Field(default=None, exclude=True)  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model_name: str = Field(default="mixtral-8x7b-32768", alias="model")
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    groq_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `groq_API_KEY` if not provided.""" | |
    groq_api_base: Optional[str] = Field(default=None, alias="base_url")
    """Base URL path for API requests, leave blank if not using a proxy or service
    emulator."""
    # to support explicit proxy for Groq
    groq_proxy: Optional[str] = None
    request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
        default=None, alias="timeout"
    )
    """Timeout for requests to Groq completion API. Can be float, httpx.Timeout or
    None."""
    max_retries: int = 2
    """Maximum number of retries to make when generating."""
    streaming: bool = False
    """Whether to stream the results or not."""
    n: int = 1
    """Number of chat completions to generate for each prompt."""
    max_tokens: Optional[int] = None
    """Maximum number of tokens to generate."""
    stop: Optional[List[str]] = Field(None, alias="stop_sequences")
    """Default stop sequences."""
    default_headers: Union[Mapping[str, str], None] = None
    default_query: Union[Mapping[str, object], None] = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
    http_client: Union[Any, None] = None
    """Optional httpx.Client."""
    http_async_client: Union[Any, None] = None
    """Optional httpx.AsyncClient. Only used for async invocations. Must specify
    http_client as well if you'd like a custom client for sync invocations."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                warnings.warn(
                    f"""WARNING! {field_name} is not a default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values
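
    # A minimal usage sketch (not part of this module): parameters unknown to
    # this class are routed into ``model_kwargs`` by ``build_extra`` above,
    # with a warning. The ``top_p`` value here is illustrative:
    #
    #     llm = ChatGroq(model_name="mixtral-8x7b-32768", top_p=0.9)
    #     assert llm.model_kwargs == {"top_p": 0.9}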

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exist in environment."""
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")
        if values["temperature"] == 0:
            values["temperature"] = 1e-8

        values["groq_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "groq_api_key", "GROQ_API_KEY")
        )
        values["groq_api_base"] = values["groq_api_base"] or os.getenv("GROQ_API_BASE")
        values["groq_proxy"] = values["groq_proxy"] or os.getenv("GROQ_PROXY")
        client_params = {
            "api_key": values["groq_api_key"].get_secret_value(),
            "base_url": values["groq_api_base"],
            "timeout": values["request_timeout"],
            "max_retries": values["max_retries"],
            "default_headers": values["default_headers"],
            "default_query": values["default_query"],
        }

        try:
            import groq

            sync_specific = {"http_client": values["http_client"]}
            if not values.get("client"):
                values["client"] = groq.Groq(
                    **client_params, **sync_specific
                ).chat.completions
            if not values.get("async_client"):
                async_specific = {"http_client": values["http_async_client"]}
                values["async_client"] = groq.AsyncGroq(
                    **client_params, **async_specific
                ).chat.completions
        except ImportError:
            raise ImportError(
                "Could not import groq python package. "
                "Please install it with `pip install groq`."
            )
        return values
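
    # A minimal configuration sketch (placeholder key, not part of this
    # module): the validator above reads GROQ_API_KEY, and optionally
    # GROQ_API_BASE and GROQ_PROXY, from the environment, so no explicit
    # ``api_key`` argument is needed:
    #
    #     import os
    #     os.environ["GROQ_API_KEY"] = "gsk_..."  # placeholder, not a real key
    #     llm = ChatGroq()  # sync and async clients are built by the validator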

    #
    # Serializable class method overrides
    #

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"groq_api_key": "GROQ_API_KEY"}

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by LangChain."""
        return True

    #
    # BaseChatModel method overrides
    #

    @property
    def _llm_type(self) -> str:
        """Return type of model."""
        return "groq-chat"

    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        params = self._get_invocation_params(stop=stop, **kwargs)
        ls_params = LangSmithParams(
            ls_provider="groq",
            ls_model_name=self.model_name,
            ls_model_type="chat",
            ls_temperature=params.get("temperature", self.temperature),
        )
        if ls_max_tokens := params.get("max_tokens", self.max_tokens):
            ls_params["ls_max_tokens"] = ls_max_tokens
        if ls_stop := stop or params.get("stop", None) or self.stop:
            ls_params["ls_stop"] = ls_stop
        return ls_params

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {
            **params,
            **kwargs,
        }
        response = self.client.create(messages=message_dicts, **params)
        return self._create_chat_result(response)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {
            **params,
            **kwargs,
        }
        response = await self.async_client.create(messages=message_dicts, **params)
        return self._create_chat_result(response)
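
    # A minimal sketch of the public entry points backed by _generate and
    # _agenerate above; the prompt is illustrative:
    #
    #     llm.invoke("Explain what an LPU is in one sentence.")
    #     # or, from async code:
    #     # await llm.ainvoke("Explain what an LPU is in one sentence.")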

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)

        # The Groq API does not support streaming with tools yet, so fall back
        # to a single non-streaming call and emit its result as one chunk.
        if "tools" in kwargs:
            response = self.client.create(
                messages=message_dicts, **{**params, **kwargs}
            )
            chat_result = self._create_chat_result(response)
            generation = chat_result.generations[0]
            message = generation.message
            tool_call_chunks = [
                {
                    "name": rtc["function"].get("name"),
                    "args": rtc["function"].get("arguments"),
                    "id": rtc.get("id"),
                    "index": rtc.get("index"),
                }
                for rtc in message.additional_kwargs.get("tool_calls", [])
            ]
            chunk_ = ChatGenerationChunk(
                message=AIMessageChunk(
                    content=message.content,
                    additional_kwargs=message.additional_kwargs,
                    tool_call_chunks=tool_call_chunks,
                ),
                generation_info=generation.generation_info,
            )
            if run_manager:
                geninfo = chunk_.generation_info or {}
                run_manager.on_llm_new_token(
                    chunk_.text,
                    chunk=chunk_,
                    logprobs=geninfo.get("logprobs"),
                )
            yield chunk_
            return

        params = {**params, **kwargs, "stream": True}
        default_chunk_class = AIMessageChunk
        for chunk in self.client.create(messages=message_dicts, **params):
            if not isinstance(chunk, dict):
                chunk = chunk.dict()
            if len(chunk["choices"]) == 0:
                continue
            choice = chunk["choices"][0]
            chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            generation_info = {}
            if finish_reason := choice.get("finish_reason"):
                generation_info["finish_reason"] = finish_reason
            logprobs = choice.get("logprobs")
            if logprobs:
                generation_info["logprobs"] = logprobs
            default_chunk_class = chunk.__class__
            chunk = ChatGenerationChunk(
                message=chunk, generation_info=generation_info or None
            )
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
            yield chunk
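
    # A minimal streaming sketch backed by _stream above; each yielded chunk
    # is an AIMessageChunk carrying a fragment of the reply:
    #
    #     for chunk in llm.stream("Write a haiku about inference speed."):
    #         print(chunk.content, end="", flush=True)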

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)

        # The Groq API does not support streaming with tools yet, so fall back
        # to a single non-streaming call and emit its result as one chunk.
        if "tools" in kwargs:
            response = await self.async_client.create(
                messages=message_dicts, **{**params, **kwargs}
            )
            chat_result = self._create_chat_result(response)
            generation = chat_result.generations[0]
            message = generation.message
            tool_call_chunks = [
                {
                    "name": rtc["function"].get("name"),
                    "args": rtc["function"].get("arguments"),
                    "id": rtc.get("id"),
                    "index": rtc.get("index"),
                }
                for rtc in message.additional_kwargs.get("tool_calls", [])
            ]
            chunk_ = ChatGenerationChunk(
                message=AIMessageChunk(
                    content=message.content,
                    additional_kwargs=message.additional_kwargs,
                    tool_call_chunks=tool_call_chunks,
                ),
                generation_info=generation.generation_info,
            )
            if run_manager:
                geninfo = chunk_.generation_info or {}
                await run_manager.on_llm_new_token(
                    chunk_.text,
                    chunk=chunk_,
                    logprobs=geninfo.get("logprobs"),
                )
            yield chunk_
            return

        params = {**params, **kwargs, "stream": True}
        default_chunk_class = AIMessageChunk
        async for chunk in await self.async_client.create(
            messages=message_dicts, **params
        ):
            if not isinstance(chunk, dict):
                chunk = chunk.dict()
            if len(chunk["choices"]) == 0:
                continue
            choice = chunk["choices"][0]
            chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            generation_info = {}
            if finish_reason := choice.get("finish_reason"):
                generation_info["finish_reason"] = finish_reason
            logprobs = choice.get("logprobs")
            if logprobs:
                generation_info["logprobs"] = logprobs
            default_chunk_class = chunk.__class__
            chunk = ChatGenerationChunk(
                message=chunk, generation_info=generation_info or None
            )
            if run_manager:
                await run_manager.on_llm_new_token(
                    token=chunk.text, chunk=chunk, logprobs=logprobs
                )
            yield chunk
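
    # The async counterpart, driven by _astream above (same chunk semantics):
    #
    #     async for chunk in llm.astream("Write a haiku about inference speed."):
    #         print(chunk.content, end="", flush=True)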

    #
    # Internal methods
    #

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Groq API."""
        params = {
            "model": self.model_name,
            "stream": self.streaming,
            "n": self.n,
            "temperature": self.temperature,
            "stop": self.stop,
            **self.model_kwargs,
        }
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        return params

    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
        generations = []
        if not isinstance(response, dict):
            response = response.dict()
        for res in response["choices"]:
            message = _convert_dict_to_message(res["message"])
            generation_info = dict(finish_reason=res.get("finish_reason"))
            if "logprobs" in res:
                generation_info["logprobs"] = res["logprobs"]
            gen = ChatGeneration(
                message=message,
                generation_info=generation_info,
            )
            generations.append(gen)
        token_usage = response.get("usage", {})
        llm_output = {
            "token_usage": token_usage,
            "model_name": self.model_name,
            "system_fingerprint": response.get("system_fingerprint", ""),
        }
        return ChatResult(generations=generations, llm_output=llm_output)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = self._default_params
        if stop is not None:
            params["stop"] = stop
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}
        system_fingerprint = None
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]
            if token_usage is not None:
                for k, v in token_usage.items():
                    if k in overall_token_usage and v is not None:
                        overall_token_usage[k] += v
                    else:
                        overall_token_usage[k] = v
            if system_fingerprint is None:
                system_fingerprint = output.get("system_fingerprint")
        combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
        if system_fingerprint:
            combined["system_fingerprint"] = system_fingerprint
        return combined
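
    # A sketch of what _combine_llm_outputs produces when llm.generate()
    # returns several generations; the token counts and fingerprint are
    # illustrative, not real output:
    #
    #     {
    #         "token_usage": {"prompt_tokens": 24, "completion_tokens": 58,
    #                         "total_tokens": 82},
    #         "model_name": "mixtral-8x7b-32768",
    #         "system_fingerprint": "fp_...",
    #     }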

    def bind_functions(
        self,
        functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        function_call: Optional[
            Union[_FunctionCall, str, Literal["auto", "none"]]
        ] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind functions (and other objects) to this chat model.

        Assumes the model is compatible with the OpenAI function-calling API.

        NOTE: Using bind_tools is recommended instead, as the `functions` and
        `function_call` request parameters are officially deprecated.

        Args:
            functions: A list of function definitions to bind to this chat model.
                Can be a dictionary, pydantic model, or callable. Pydantic
                models and callables will be automatically converted to
                their schema dictionary representation.
            function_call: Which function to require the model to call.
                Must be the name of the single provided function or
                "auto" to automatically determine which function to call
                (if any).
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
        """
        formatted_functions = [convert_to_openai_function(fn) for fn in functions]
        if function_call is not None:
            function_call = (
                {"name": function_call}
                if isinstance(function_call, str)
                and function_call not in ("auto", "none")
                else function_call
            )
            if isinstance(function_call, dict) and len(formatted_functions) != 1:
                raise ValueError(
                    "When specifying `function_call`, you must provide exactly one "
                    "function."
                )
            if (
                isinstance(function_call, dict)
                and formatted_functions[0]["name"] != function_call["name"]
            ):
                raise ValueError(
                    f"Function call {function_call} was specified, but the only "
                    f"provided function was {formatted_functions[0]['name']}."
                )
            kwargs = {**kwargs, "function_call": function_call}
        return super().bind(
            functions=formatted_functions,
            **kwargs,
        )
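
    # A usage sketch for this (deprecated-style) path; ``get_weather`` is a
    # hypothetical annotated callable, not part of this module:
    #
    #     def get_weather(city: str) -> str:
    #         '''Get the current weather for a city.'''
    #         ...
    #
    #     llm_with_fns = llm.bind_functions(
    #         [get_weather], function_call="get_weather"
    #     )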

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        *,
        tool_choice: Optional[
            Union[dict, str, Literal["auto", "any", "none"], bool]
        ] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
                models, callables, and BaseTools will be automatically converted to
                their schema dictionary representation.
            tool_choice: Which tool to require the model to call.
                Must be the name of the single provided function,
                "auto" to automatically determine which function to call
                with the option to not call any function, "any" to enforce that some
                function is called, or a dict of the form:
                {"type": "function", "function": {"name": <<tool_name>>}}.
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        if tool_choice is not None and tool_choice:
            if isinstance(tool_choice, str) and (
                tool_choice not in ("auto", "any", "none")
            ):
                tool_choice = {"type": "function", "function": {"name": tool_choice}}
            if isinstance(tool_choice, dict) and (len(formatted_tools) != 1):
                raise ValueError(
                    "When specifying `tool_choice`, you must provide exactly one "
                    f"tool. Received {len(formatted_tools)} tools."
                )
            if isinstance(tool_choice, dict) and (
                formatted_tools[0]["function"]["name"]
                != tool_choice["function"]["name"]
            ):
                raise ValueError(
                    f"Tool choice {tool_choice} was specified, but the only "
                    f"provided tool was {formatted_tools[0]['function']['name']}."
                )
            if isinstance(tool_choice, bool):
                if len(tools) > 1:
                    raise ValueError(
                        "tool_choice can only be True when there is one tool. Received "
                        f"{len(tools)} tools."
                    )
                tool_name = formatted_tools[0]["function"]["name"]
                tool_choice = {
                    "type": "function",
                    "function": {"name": tool_name},
                }
            kwargs["tool_choice"] = tool_choice
        return super().bind(tools=formatted_tools, **kwargs)
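
    # A minimal usage sketch for bind_tools; the ``GetWeather`` model and the
    # prompt are illustrative:
    #
    #     from langchain_core.pydantic_v1 import BaseModel, Field
    #
    #     class GetWeather(BaseModel):
    #         '''Get the current weather in a given location.'''
    #
    #         location: str = Field(..., description="City and country")
    #
    #     llm_with_tools = llm.bind_tools([GetWeather], tool_choice="GetWeather")
    #     msg = llm_with_tools.invoke("What's the weather in Paris?")
    #     msg.tool_calls  # -> [{"name": "GetWeather", "args": {...}, ...}]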

    def with_structured_output(
        self,
        schema: Optional[Union[Dict, Type[BaseModel]]] = None,
        *,
        method: Literal["function_calling", "json_mode"] = "function_calling",
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        """Model wrapper that returns outputs formatted to match the given schema.

        Args:
            schema: The output schema as a dict or a Pydantic class. If a Pydantic
                class then the model output will be an object of that class. If a
                dict then the model output will be a dict. With a Pydantic class
                the returned attributes will be validated, whereas with a dict they
                will not be. If `method` is "function_calling" and `schema` is a
                dict, then the dict must match the OpenAI function-calling spec.
            method: The method for steering model generation, either
                "function_calling" or "json_mode". If "function_calling" then the
                schema will be converted to an OpenAI function and the returned
                model will make use of the function-calling API. If "json_mode"
                then Groq's JSON mode will be used. Note that if using "json_mode"
                then you must include instructions for formatting the output to
                the desired schema in the model call.
            include_raw: If False then only the parsed structured output is
                returned. If an error occurs during model output parsing it will
                be raised. If True then both the raw model response (a BaseMessage)
                and the parsed model response will be returned. If an error occurs
                during output parsing it will be caught and returned as well. The
                final output is always a dict with keys "raw", "parsed", and
                "parsing_error".

        Returns:
            A Runnable that takes any ChatModel input and returns as output:

                If include_raw is True then a dict with keys:
                    raw: BaseMessage
                    parsed: Optional[_DictOrPydantic]
                    parsing_error: Optional[BaseException]

                If include_raw is False then just _DictOrPydantic is returned,
                where _DictOrPydantic depends on the schema:

                    If schema is a Pydantic class then _DictOrPydantic is the
                    Pydantic class.

                    If schema is a dict then _DictOrPydantic is a dict.

        Example: Function-calling, Pydantic schema (method="function_calling", include_raw=False):
            .. code-block:: python

                from langchain_groq import ChatGroq
                from langchain_core.pydantic_v1 import BaseModel

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''

                    answer: str
                    justification: str

                llm = ChatGroq(temperature=0)
                structured_llm = llm.with_structured_output(AnswerWithJustification)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> AnswerWithJustification(
                #     answer='A pound of bricks and a pound of feathers weigh the same.',
                #     justification="Both a pound of bricks and a pound of feathers have been defined to have the same weight. The 'pound' is a unit of weight, so any two things that are described as weighing a pound will weigh the same."
                # )

        Example: Function-calling, Pydantic schema (method="function_calling", include_raw=True):
            .. code-block:: python

                from langchain_groq import ChatGroq
                from langchain_core.pydantic_v1 import BaseModel

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''

                    answer: str
                    justification: str

                llm = ChatGroq(temperature=0)
                structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> {
                #     'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_01htjn3cspevxbqc1d7nkk8wab', 'function': {'arguments': '{"answer": "A pound of bricks and a pound of feathers weigh the same.", "justification": "Both a pound of bricks and a pound of feathers have been defined to have the same weight. The \'pound\' is a unit of weight, so any two things that are described as weighing a pound will weigh the same.", "unit": "pounds"}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}, id='run-456beee6-65f6-4e80-88af-a6065480822c-0'),
                #     'parsed': AnswerWithJustification(answer='A pound of bricks and a pound of feathers weigh the same.', justification="Both a pound of bricks and a pound of feathers have been defined to have the same weight. The 'pound' is a unit of weight, so any two things that are described as weighing a pound will weigh the same."),
                #     'parsing_error': None
                # }

        Example: Function-calling, dict schema (method="function_calling", include_raw=False):
            .. code-block:: python

                from langchain_groq import ChatGroq
                from langchain_core.pydantic_v1 import BaseModel
                from langchain_core.utils.function_calling import convert_to_openai_tool

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''

                    answer: str
                    justification: str

                dict_schema = convert_to_openai_tool(AnswerWithJustification)
                llm = ChatGroq(temperature=0)
                structured_llm = llm.with_structured_output(dict_schema)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> {
                #     'answer': 'A pound of bricks and a pound of feathers weigh the same.',
                #     'justification': "Both a pound of bricks and a pound of feathers have been defined to have the same weight. The 'pound' is a unit of weight, so any two things that are described as weighing a pound will weigh the same.",
                #     'unit': 'pounds'
                # }

        Example: JSON mode, Pydantic schema (method="json_mode", include_raw=True):
            .. code-block:: python

                from langchain_groq import ChatGroq
                from langchain_core.pydantic_v1 import BaseModel

                class AnswerWithJustification(BaseModel):
                    answer: str
                    justification: str

                llm = ChatGroq(temperature=0)
                structured_llm = llm.with_structured_output(
                    AnswerWithJustification,
                    method="json_mode",
                    include_raw=True
                )

                structured_llm.invoke(
                    "Answer the following question. "
                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
                    "What's heavier a pound of bricks or a pound of feathers?"
                )
                # -> {
                #     'raw': AIMessage(content='{\n    "answer": "A pound of bricks is the same weight as a pound of feathers.",\n    "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The material being weighed does not affect the weight, only the volume or number of items being weighed."\n}', id='run-e5453bc5-5025-4833-95f9-4967bf6d5c4f-0'),
                #     'parsed': AnswerWithJustification(answer='A pound of bricks is the same weight as a pound of feathers.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The material being weighed does not affect the weight, only the volume or number of items being weighed.'),
                #     'parsing_error': None
                # }

        Example: JSON mode, no schema (schema=None, method="json_mode", include_raw=True):
            .. code-block:: python

                from langchain_groq import ChatGroq

                llm = ChatGroq(temperature=0)
                structured_llm = llm.with_structured_output(method="json_mode", include_raw=True)

                structured_llm.invoke(
                    "Answer the following question. "
                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
                    "What's heavier a pound of bricks or a pound of feathers?"
                )
                # -> {
                #     'raw': AIMessage(content='{\n    "answer": "A pound of bricks is the same weight as a pound of feathers.",\n    "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The material doesn\'t change the weight, only the volume or space that the material takes up."\n}', id='run-a4abbdb6-c20e-456f-bfff-da906a7e76b5-0'),
                #     'parsed': {
                #         'answer': 'A pound of bricks is the same weight as a pound of feathers.',
                #         'justification': "Both a pound of bricks and a pound of feathers weigh one pound. The material doesn't change the weight, only the volume or space that the material takes up."
                #     },
                #     'parsing_error': None
                # }
        """  # noqa: E501
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")
        is_pydantic_schema = _is_pydantic_class(schema)
        if method == "function_calling":
            if schema is None:
                raise ValueError(
                    "schema must be specified when method is 'function_calling'. "
                    "Received None."
                )
            llm = self.bind_tools([schema], tool_choice=True)
            if is_pydantic_schema:
                output_parser: OutputParserLike = PydanticToolsParser(
                    tools=[schema], first_tool_only=True
                )
            else:
                key_name = convert_to_openai_tool(schema)["function"]["name"]
                output_parser = JsonOutputKeyToolsParser(
                    key_name=key_name, first_tool_only=True
                )
        elif method == "json_mode":
            llm = self.bind(response_format={"type": "json_object"})
            output_parser = (
                PydanticOutputParser(pydantic_object=schema)
                if is_pydantic_schema
                else JsonOutputParser()
            )
        else:
            raise ValueError(
                f"Unrecognized method argument. Expected one of 'function_calling' "
                f"or 'json_mode'. Received: '{method}'"
            )
        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser

def _is_pydantic_class(obj: Any) -> bool:
    return isinstance(obj, type) and issubclass(obj, BaseModel)


class _FunctionCall(TypedDict):
    name: str


#
# Type conversion helpers
#

def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message to a dictionary.

    Args:
        message: The LangChain message.

    Returns:
        The dictionary.
    """
    message_dict: Dict[str, Any]
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
        if "function_call" in message.additional_kwargs:
            message_dict["function_call"] = message.additional_kwargs["function_call"]
            # If function call only, content is None not empty string
            if message_dict["content"] == "":
                message_dict["content"] = None
        if message.tool_calls or message.invalid_tool_calls:
            message_dict["tool_calls"] = [
                _lc_tool_call_to_groq_tool_call(tc) for tc in message.tool_calls
            ] + [
                _lc_invalid_tool_call_to_groq_tool_call(tc)
                for tc in message.invalid_tool_calls
            ]
        elif "tool_calls" in message.additional_kwargs:
            message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
            # If tool calls only, content is None not empty string
            if message_dict["content"] == "":
                message_dict["content"] = None
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, FunctionMessage):
        message_dict = {
            "role": "function",
            "content": message.content,
            "name": message.name,
        }
    elif isinstance(message, ToolMessage):
        message_dict = {
            "role": "tool",
            "content": message.content,
            "tool_call_id": message.tool_call_id,
        }
    else:
        raise TypeError(f"Got unknown type {message}")
    if "name" in message.additional_kwargs:
        message_dict["name"] = message.additional_kwargs["name"]
    return message_dict
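
# A quick round-trip sketch for _convert_message_to_dict; the message content
# and tool_call_id are illustrative:
#
#     _convert_message_to_dict(HumanMessage(content="hi"))
#     # -> {"role": "user", "content": "hi"}
#     _convert_message_to_dict(ToolMessage(content="72F", tool_call_id="call_1"))
#     # -> {"role": "tool", "content": "72F", "tool_call_id": "call_1"}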

def _convert_delta_to_message_chunk(
    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    role = cast(str, _dict.get("role"))
    content = cast(str, _dict.get("content") or "")
    additional_kwargs: Dict = {}
    if _dict.get("function_call"):
        function_call = dict(_dict["function_call"])
        if "name" in function_call and function_call["name"] is None:
            function_call["name"] = ""
        additional_kwargs["function_call"] = function_call
    if _dict.get("tool_calls"):
        additional_kwargs["tool_calls"] = _dict["tool_calls"]

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    elif role == "function" or default_class == FunctionMessageChunk:
        return FunctionMessageChunk(content=content, name=_dict["name"])
    elif role == "tool" or default_class == ToolMessageChunk:
        return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)
    else:
        return default_class(content=content)  # type: ignore
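
# A sketch of delta conversion during streaming; the partial content "Hel" is
# illustrative:
#
#     _convert_delta_to_message_chunk(
#         {"role": "assistant", "content": "Hel"}, AIMessageChunk
#     )
#     # -> AIMessageChunk(content="Hel")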

def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Convert a dictionary to a LangChain message.

    Args:
        _dict: The dictionary.

    Returns:
        The LangChain message.
    """
    id_ = _dict.get("id")
    role = _dict.get("role")
    if role == "user":
        return HumanMessage(content=_dict.get("content", ""))
    elif role == "assistant":
        content = _dict.get("content", "") or ""
        additional_kwargs: Dict = {}
        if function_call := _dict.get("function_call"):
            additional_kwargs["function_call"] = dict(function_call)
        tool_calls = []
        invalid_tool_calls = []
        if raw_tool_calls := _dict.get("tool_calls"):
            additional_kwargs["tool_calls"] = raw_tool_calls
            for raw_tool_call in raw_tool_calls:
                try:
                    tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
                except Exception as e:
                    invalid_tool_calls.append(
                        make_invalid_tool_call(raw_tool_call, str(e))
                    )
        return AIMessage(
            content=content,
            id=id_,
            additional_kwargs=additional_kwargs,
            tool_calls=tool_calls,
            invalid_tool_calls=invalid_tool_calls,
        )
    elif role == "system":
        return SystemMessage(content=_dict.get("content", ""))
    elif role == "function":
        return FunctionMessage(content=_dict.get("content", ""), name=_dict.get("name"))
    elif role == "tool":
        additional_kwargs = {}
        if "name" in _dict:
            additional_kwargs["name"] = _dict["name"]
        return ToolMessage(
            content=_dict.get("content", ""),
            tool_call_id=_dict.get("tool_call_id"),
            additional_kwargs=additional_kwargs,
        )
    else:
        return ChatMessage(content=_dict.get("content", ""), role=role)
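
# A sketch of response-dict conversion; the payload mirrors a Groq chat
# completion "message" object and is illustrative:
#
#     _convert_dict_to_message(
#         {"role": "assistant", "content": "Hello!", "id": "chatcmpl-123"}
#     )
#     # -> AIMessage(content="Hello!", id="chatcmpl-123")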

def _lc_tool_call_to_groq_tool_call(tool_call: ToolCall) -> dict:
    return {
        "type": "function",
        "id": tool_call["id"],
        "function": {
            "name": tool_call["name"],
            "arguments": json.dumps(tool_call["args"]),
        },
    }


def _lc_invalid_tool_call_to_groq_tool_call(
    invalid_tool_call: InvalidToolCall,
) -> dict:
    return {
        "type": "function",
        "id": invalid_tool_call["id"],
        "function": {
            "name": invalid_tool_call["name"],
            "arguments": invalid_tool_call["args"],
        },
    }
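
# A serialization sketch for the helpers above; the tool call values are
# illustrative:
#
#     _lc_tool_call_to_groq_tool_call(
#         ToolCall(name="get_weather", args={"city": "Paris"}, id="call_123")
#     )
#     # -> {"type": "function", "id": "call_123",
#     #     "function": {"name": "get_weather",
#     #                  "arguments": '{"city": "Paris"}'}}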