# NOTE: page-header residue from the scraped source ("Spaces:", status "Runtime error").
import asyncio
import json
import os
from typing import Any, Dict, List, Optional

import numpy as np
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.runnables.config import run_in_executor
class BedrockEmbeddings(BaseModel, Embeddings):
    """Bedrock embedding models.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Bedrock service.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import BedrockEmbeddings

            region_name = "us-east-1"
            credentials_profile_name = "default"
            model_id = "amazon.titan-embed-text-v1"

            be = BedrockEmbeddings(
                credentials_profile_name=credentials_profile_name,
                region_name=region_name,
                model_id=model_id,
            )
    """

    client: Any  #: :meta private:
    """Bedrock client. Built by ``validate_environment`` when not supplied."""

    region_name: Optional[str] = None
    """The aws region e.g., `us-west-2`. Falls back to AWS_DEFAULT_REGION env variable
    or region specified in ~/.aws/config in case it is not provided here.
    """

    credentials_profile_name: Optional[str] = None
    """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
    has either access keys or role information specified.
    If not specified, the default credential profile or, if on an EC2 instance,
    credentials from IMDS will be used.
    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
    """

    model_id: str = "amazon.titan-embed-text-v1"
    """Id of the model to call, e.g., amazon.titan-embed-text-v1, this is
    equivalent to the modelId property in the list-foundation-models api"""

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model."""

    endpoint_url: Optional[str] = None
    """Needed if you don't want to default to us-east-1 endpoint"""

    normalize: bool = False
    """Whether the embeddings should be normalized to unit vectors"""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Create a boto3 ``bedrock-runtime`` client unless one was supplied.

        Registered as a pydantic root validator, so it runs when the model is
        constructed and fills in ``client`` from the profile, region and
        endpoint settings.

        Raises:
            ImportError: If the boto3 package is not installed.
            ValueError: If the boto3 session or client cannot be created.
        """
        if values.get("client") is not None:
            return values

        try:
            import boto3
        except ImportError as e:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from e

        try:
            profile_name = values.get("credentials_profile_name")
            if profile_name is not None:
                session = boto3.Session(profile_name=profile_name)
            else:
                # use default credentials
                session = boto3.Session()

            # Only forward settings the caller actually provided so boto3's own
            # resolution (env vars, ~/.aws/config) applies to the rest.
            client_params = {}
            if values.get("region_name"):
                client_params["region_name"] = values["region_name"]
            if values.get("endpoint_url"):
                client_params["endpoint_url"] = values["endpoint_url"]

            values["client"] = session.client("bedrock-runtime", **client_params)
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                f"profile name are valid. Bedrock error: {e}"
            ) from e

        return values

    def _embedding_func(self, text: str) -> List[float]:
        """Call out to Bedrock embedding endpoint for a single text.

        Args:
            text: The text to embed.

        Returns:
            The embedding vector taken from the model response.

        Raises:
            ValueError: If the request fails or the response cannot be parsed.
        """
        # replace newlines, which can negatively affect performance.
        # ``os.linesep`` is "\r\n" on Windows, so bare "\n" is replaced too
        # (a no-op on POSIX, where os.linesep is already "\n").
        text = text.replace(os.linesep, " ").replace("\n", " ")

        # format input body for provider
        provider = self.model_id.split(".")[0]
        input_body: Dict[str, Any] = dict(self.model_kwargs or {})
        if provider == "cohere":
            input_body.setdefault("input_type", "search_document")
            input_body["texts"] = [text]
        else:
            # includes common provider == "amazon"
            input_body["inputText"] = text
        body = json.dumps(input_body)

        try:
            # invoke bedrock API
            response = self.client.invoke_model(
                body=body,
                modelId=self.model_id,
                accept="application/json",
                contentType="application/json",
            )
            # format output based on provider
            response_body = json.loads(response.get("body").read())
            if provider == "cohere":
                return response_body.get("embeddings")[0]
            # includes common provider == "amazon"
            return response_body.get("embedding")
        except Exception as e:
            raise ValueError(f"Error raised by inference endpoint: {e}") from e

    def _normalize_vector(self, embeddings: List[float]) -> List[float]:
        """Normalize the embedding to a unit vector.

        A zero vector has no direction; it is returned unchanged rather than
        divided by zero (which would produce NaNs).
        """
        emb = np.array(embeddings)
        norm = np.linalg.norm(emb)
        if norm == 0:
            return emb.tolist()
        return (emb / norm).tolist()

    def _embed_one(self, text: str) -> List[float]:
        """Embed a single text, normalizing it when ``self.normalize`` is set."""
        embedding = self._embedding_func(text)
        if self.normalize:
            return self._normalize_vector(embedding)
        return embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a Bedrock model.

        Texts are sent one request at a time, in order.

        Args:
            texts: The list of texts to embed

        Returns:
            List of embeddings, one for each text.
        """
        return [self._embed_one(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a Bedrock model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._embed_one(text)

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous compute query embeddings using a Bedrock model.

        Runs the blocking ``embed_query`` call in an executor.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return await run_in_executor(None, self.embed_query, text)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous compute doc embeddings using a Bedrock model.

        All texts are scheduled at once via ``asyncio.gather``; results keep
        the input order.

        Args:
            texts: The list of texts to embed

        Returns:
            List of embeddings, one for each text.
        """
        result = await asyncio.gather(*[self.aembed_query(text) for text in texts])
        return list(result)