ff_li
oai
e7de191
import os
from typing import Dict, Iterator, List, Optional
import openai
from agent.llm.base import BaseChatModel
from typing import Dict, List, Literal, Optional, Union
class ChatAsOAI(BaseChatModel):
    """Chat model backed by an OpenAI-compatible ``ChatCompletion`` endpoint.

    Configuration is read from environment variables when an instance is
    created:

    * ``OPENAI_API_BASE`` -- endpoint base URL; applied only when set, so the
      ``openai`` library's own default base URL is kept otherwise.
    * ``OPENAI_API_KEY`` -- API key; falls back to ``'EMPTY'``
      (NOTE(review): presumably for local OpenAI-compatible servers that
      ignore the key -- confirm).
    * ``OPENAI_MODEL_NAME`` -- overrides the ``model`` constructor argument.

    NOTE: ``api_base`` / ``api_key`` are assigned on the ``openai`` module
    itself, so they are process-wide and affect every other user of that
    module in the same process.
    """

    def __init__(self, model: str):
        """Configure the ``openai`` module and pick the model name.

        Args:
            model: Model name to use unless ``OPENAI_MODEL_NAME`` is set.
        """
        super().__init__()
        api_base = os.getenv('OPENAI_API_BASE')
        # Assigning None here would replace the library's default base URL
        # and break every request when the variable is unset.
        if api_base:
            openai.api_base = api_base
        openai.api_key = os.getenv('OPENAI_API_KEY', 'EMPTY')
        self.model = os.getenv('OPENAI_MODEL_NAME', model)

    def _chat_stream(
        self,
        messages: List[Dict],
        stop: Optional[List[str]] = None,
    ) -> Iterator[str]:
        """Stream a chat completion, yielding text fragments as they arrive.

        Args:
            messages: Chat messages forwarded unchanged to
                ``openai.ChatCompletion.create``.
            stop: Optional stop sequences forwarded to the API.

        Yields:
            Non-empty content fragments. Chunks carrying no text (an empty
            ``choices`` list, a delta without ``content``, or ``content`` of
            ``None``/``''``) are skipped so callers only receive ``str``.
        """
        response = openai.ChatCompletion.create(model=self.model,
                                                messages=messages,
                                                stop=stop,
                                                stream=True)
        # TODO: error handling
        for chunk in response:
            # Some OpenAI-compatible servers emit chunks with no choices
            # (e.g. metadata-only chunks); indexing [0] would raise.
            if not chunk.choices:
                continue
            content = getattr(chunk.choices[0].delta, 'content', None)
            # content may be present but null (e.g. function-call deltas);
            # yielding None would break string concatenation downstream.
            if content:
                yield content

    def _chat_no_stream(
        self,
        messages: List[Dict],
        stop: Optional[List[str]] = None,
    ) -> str:
        """Run a single non-streaming chat completion.

        Args:
            messages: Chat messages forwarded unchanged to
                ``openai.ChatCompletion.create``.
            stop: Optional stop sequences forwarded to the API.

        Returns:
            The first choice's message content, or ``''`` when the server
            returns no content (keeps the ``str`` return contract).
        """
        response = openai.ChatCompletion.create(model=self.model,
                                                messages=messages,
                                                stop=stop,
                                                stream=False)
        # TODO: error handling
        return getattr(response.choices[0].message, 'content', None) or ''

    def chat_with_functions(self,
                            messages: List[Dict],
                            functions: Optional[List[Dict]] = None) -> Dict:
        """Run a chat completion, optionally exposing function definitions.

        Args:
            messages: Chat messages forwarded unchanged to the API.
            functions: Optional function schemas; the ``functions`` argument
                is only sent when this is non-empty.

        Returns:
            The first choice's ``message`` object as returned by the
            ``openai`` library.
        """
        request: Dict = {'model': self.model, 'messages': messages}
        if functions:
            request['functions'] = functions
        response = openai.ChatCompletion.create(**request)
        # TODO: error handling
        return response.choices[0].message