import os import typing import json from pydantic import root_validator from langchain.llms import SagemakerEndpoint from langchain.llms.sagemaker_endpoint import LLMContentHandler from src.utils import FakeTokenizer class ChatContentHandler(LLMContentHandler): content_type = "application/json" accepts = "application/json" def transform_input(self, prompt: str, model_kwargs: typing.Dict) -> bytes: messages0 = [] openai_system_prompt = "You are a helpful assistant." if openai_system_prompt: messages0.append({"role": "system", "content": openai_system_prompt}) messages0.append({'role': 'user', 'content': prompt}) input_dict = {'inputs': [messages0], "parameters": model_kwargs} return json.dumps(input_dict).encode("utf-8") def transform_output(self, output: bytes) -> str: response_json = json.loads(output.read().decode("utf-8")) return response_json[0]["generation"]['content'] class BaseContentHandler(LLMContentHandler): content_type = "application/json" accepts = "application/json" def transform_input(self, prompt: str, model_kwargs: typing.Dict) -> bytes: input_dict = {'inputs': prompt, "parameters": model_kwargs} return json.dumps(input_dict).encode("utf-8") def transform_output(self, output: bytes) -> str: response_json = json.loads(output.read().decode("utf-8")) return response_json[0]["generation"] class H2OSagemakerEndpoint(SagemakerEndpoint): aws_access_key_id: str = "" aws_secret_access_key: str = "" tokenizer: typing.Any = None @root_validator() def validate_environment(cls, values: typing.Dict) -> typing.Dict: """Validate that AWS credentials to and python package exists in environment.""" try: import boto3 try: if values["credentials_profile_name"] is not None: session = boto3.Session( profile_name=values["credentials_profile_name"] ) else: # use default credentials session = boto3.Session() values["client"] = session.client( "sagemaker-runtime", region_name=values['region_name'], aws_access_key_id=values['aws_access_key_id'], aws_secret_access_key=values['aws_secret_access_key'], ) except Exception as e: raise ValueError( "Could not load credentials to authenticate with AWS client. " "Please check that credentials in the specified " "profile name are valid." ) from e except ImportError: raise ImportError( "Could not import boto3 python package. " "Please install it with `pip install boto3`." ) return values def get_token_ids(self, text: str) -> typing.List[int]: tokenizer = self.tokenizer if tokenizer is not None: return tokenizer.encode(text) else: return FakeTokenizer().encode(text)['input_ids']