from __future__ import annotations
import asyncio
import json
from typing import Any, Dict, List, Optional
import aiohttp
import requests
from langchain_core._api.deprecation import deprecated
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator
def is_endpoint_live(url: str, headers: Optional[dict], payload: Any) -> bool:
    """
    Check if an endpoint is live by sending a POST request to the specified URL.

    Args:
        url (str): The URL of the endpoint to check.
        headers (Optional[dict]): HTTP headers to send with the request.
        payload (Any): Request body to send (e.g. a JSON-encoded string).

    Returns:
        bool: True if the endpoint is live (status code 200).

    Raises:
        Exception: If the endpoint returns a non-successful status code or if there is
        an error querying the endpoint.
    """
    try:
        # Keep the try body minimal: only the network call can raise
        # a RequestException (connection errors, timeouts, DNS failures).
        response = requests.post(url, headers=headers, data=payload)
    except requests.exceptions.RequestException as e:
        raise Exception(f"Error querying the endpoint: {e}")
    # Check if the status code is 200 (OK)
    if response.status_code == 200:
        return True
    # Raise an exception if the status code is not 200
    raise Exception(
        f"Endpoint returned a non-successful status code: "
        f"{response.status_code}"
    )
@deprecated(
    since="0.0.37",
    removal="0.2.0",
    message=(
        "Directly instantiating a NeMoEmbeddings from langchain-community is "
        "deprecated. Please use langchain-nvidia-ai-endpoints NVIDIAEmbeddings "
        "interface."
    ),
)
class NeMoEmbeddings(BaseModel, Embeddings):
    """NeMo embedding models."""

    # Number of texts embedded concurrently per request group in aembed_documents.
    batch_size: int = 16
    # Model name sent to the embedding endpoint.
    model: str = "NV-Embed-QA-003"
    # Full URL of the embeddings endpoint.
    api_endpoint_url: str = "http://localhost:8088/v1/embeddings"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the end point is alive using the values that are provided."""

        url = values["api_endpoint_url"]
        model = values["model"]

        # Optional: A minimal test payload and headers required by the endpoint
        headers = {"Content-Type": "application/json"}
        payload = json.dumps(
            {"input": "Hello World", "model": model, "input_type": "query"}
        )

        # Raises if the endpoint is unreachable or returns a non-200 status.
        is_endpoint_live(url, headers, payload)

        return values

    async def _aembedding_func(
        self, session: Any, text: str, input_type: str
    ) -> List[float]:
        """Async call out to the embedding endpoint.

        Args:
            session: An open aiohttp.ClientSession used to issue the request.
            text: The text to embed.
            input_type: Either "query" or "passage".

        Returns:
            Embeddings for the text.
        """
        headers = {"Content-Type": "application/json"}
        async with session.post(
            self.api_endpoint_url,
            json={"input": text, "model": self.model, "input_type": input_type},
            headers=headers,
        ) as response:
            response.raise_for_status()
            answer = await response.text()
            answer = json.loads(answer)
            return answer["data"][0]["embedding"]

    def _embedding_func(self, text: str, input_type: str) -> List[float]:
        """Call out to the NeMo embedding endpoint.

        Args:
            text: The text to embed.
            input_type: Either "query" or "passage".

        Returns:
            Embeddings for the text.
        """
        payload = json.dumps(
            {"input": text, "model": self.model, "input_type": input_type}
        )
        headers = {"Content-Type": "application/json"}

        response = requests.request(
            "POST", self.api_endpoint_url, headers=headers, data=payload
        )
        response_json = json.loads(response.text)
        embedding = response_json["data"][0]["embedding"]

        return embedding

    def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """Embed a list of document texts.

        Args:
            documents: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return [self._embedding_func(text, input_type="passage") for text in documents]

    def embed_query(self, text: str) -> List[float]:
        """Embed query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return self._embedding_func(text, input_type="query")

    async def aembed_query(self, text: str) -> List[float]:
        """Call out to NeMo's embedding endpoint async for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        async with aiohttp.ClientSession() as session:
            # BUG FIX: this previously passed input_type="passage"; queries must
            # use "query" to match the synchronous embed_query behavior.
            embedding = await self._aembedding_func(session, text, "query")
            return embedding

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to NeMo's embedding endpoint async for embedding search docs.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings: List[List[float]] = []
        async with aiohttp.ClientSession() as session:
            for start in range(0, len(texts), self.batch_size):
                text_batch = texts[start : start + self.batch_size]
                # BUG FIX: a redundant inner `for text in text_batch` loop
                # previously embedded each batch len(text_batch) times,
                # duplicating every batch's results in the output.
                tasks = [
                    self._aembedding_func(session, text, "passage")
                    for text in text_batch
                ]
                # Run all tasks in the batch concurrently.
                batch_results = await asyncio.gather(*tasks)
                # Extend the embeddings list with results from this batch.
                embeddings.extend(batch_results)
        return embeddings