Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import asyncio | |
import json | |
from typing import Any, Dict, List, Optional | |
import aiohttp | |
import requests | |
from langchain_core._api.deprecation import deprecated | |
from langchain_core.embeddings import Embeddings | |
from langchain_core.pydantic_v1 import BaseModel, root_validator | |
def is_endpoint_live(url: str, headers: Optional[dict], payload: Any) -> bool: | |
""" | |
Check if an endpoint is live by sending a GET request to the specified URL. | |
Args: | |
url (str): The URL of the endpoint to check. | |
Returns: | |
bool: True if the endpoint is live (status code 200), False otherwise. | |
Raises: | |
Exception: If the endpoint returns a non-successful status code or if there is | |
an error querying the endpoint. | |
""" | |
try: | |
response = requests.request("POST", url, headers=headers, data=payload) | |
# Check if the status code is 200 (OK) | |
if response.status_code == 200: | |
return True | |
else: | |
# Raise an exception if the status code is not 200 | |
raise Exception( | |
f"Endpoint returned a non-successful status code: " | |
f"{response.status_code}" | |
) | |
except requests.exceptions.RequestException as e: | |
# Handle any exceptions (e.g., connection errors) | |
raise Exception(f"Error querying the endpoint: {e}") | |
class NeMoEmbeddings(BaseModel, Embeddings): | |
"""NeMo embedding models.""" | |
batch_size: int = 16 | |
model: str = "NV-Embed-QA-003" | |
api_endpoint_url: str = "http://localhost:8088/v1/embeddings" | |
def validate_environment(cls, values: Dict) -> Dict: | |
"""Validate that the end point is alive using the values that are provided.""" | |
url = values["api_endpoint_url"] | |
model = values["model"] | |
# Optional: A minimal test payload and headers required by the endpoint | |
headers = {"Content-Type": "application/json"} | |
payload = json.dumps( | |
{"input": "Hello World", "model": model, "input_type": "query"} | |
) | |
is_endpoint_live(url, headers, payload) | |
return values | |
async def _aembedding_func( | |
self, session: Any, text: str, input_type: str | |
) -> List[float]: | |
"""Async call out to embedding endpoint. | |
Args: | |
text: The text to embed. | |
Returns: | |
Embeddings for the text. | |
""" | |
headers = {"Content-Type": "application/json"} | |
async with session.post( | |
self.api_endpoint_url, | |
json={"input": text, "model": self.model, "input_type": input_type}, | |
headers=headers, | |
) as response: | |
response.raise_for_status() | |
answer = await response.text() | |
answer = json.loads(answer) | |
return answer["data"][0]["embedding"] | |
def _embedding_func(self, text: str, input_type: str) -> List[float]: | |
"""Call out to Cohere's embedding endpoint. | |
Args: | |
text: The text to embed. | |
Returns: | |
Embeddings for the text. | |
""" | |
payload = json.dumps( | |
{"input": text, "model": self.model, "input_type": input_type} | |
) | |
headers = {"Content-Type": "application/json"} | |
response = requests.request( | |
"POST", self.api_endpoint_url, headers=headers, data=payload | |
) | |
response_json = json.loads(response.text) | |
embedding = response_json["data"][0]["embedding"] | |
return embedding | |
def embed_documents(self, documents: List[str]) -> List[List[float]]: | |
"""Embed a list of document texts. | |
Args: | |
texts: The list of texts to embed. | |
Returns: | |
List of embeddings, one for each text. | |
""" | |
return [self._embedding_func(text, input_type="passage") for text in documents] | |
def embed_query(self, text: str) -> List[float]: | |
return self._embedding_func(text, input_type="query") | |
async def aembed_query(self, text: str) -> List[float]: | |
"""Call out to NeMo's embedding endpoint async for embedding query text. | |
Args: | |
text: The text to embed. | |
Returns: | |
Embedding for the text. | |
""" | |
async with aiohttp.ClientSession() as session: | |
embedding = await self._aembedding_func(session, text, "passage") | |
return embedding | |
async def aembed_documents(self, texts: List[str]) -> List[List[float]]: | |
"""Call out to NeMo's embedding endpoint async for embedding search docs. | |
Args: | |
texts: The list of texts to embed. | |
Returns: | |
List of embeddings, one for each text. | |
""" | |
embeddings = [] | |
async with aiohttp.ClientSession() as session: | |
for batch in range(0, len(texts), self.batch_size): | |
text_batch = texts[batch : batch + self.batch_size] | |
for text in text_batch: | |
# Create tasks for all texts in the batch | |
tasks = [ | |
self._aembedding_func(session, text, "passage") | |
for text in text_batch | |
] | |
# Run all tasks concurrently | |
batch_results = await asyncio.gather(*tasks) | |
# Extend the embeddings list with results from this batch | |
embeddings.extend(batch_results) | |
return embeddings | |