Spaces:
Running
Running
import os | |
from typing import Dict, Iterator, List, Optional | |
import openai | |
from agent.llm.base import BaseChatModel | |
from typing import Dict, List, Literal, Optional, Union | |
class ChatAsOAI(BaseChatModel):
    """Chat model that talks to an OpenAI-compatible ChatCompletion endpoint.

    Uses the module-level ``openai.ChatCompletion`` API (openai<1.0).
    Configuration comes from environment variables:

    * ``OPENAI_API_BASE`` -- endpoint URL; when unset or empty, whatever
      ``openai.api_base`` already holds is kept.
    * ``OPENAI_API_KEY`` -- API key; defaults to ``'EMPTY'``.
    * ``OPENAI_MODEL_NAME`` -- if set, overrides the ``model`` argument.

    NOTE: ``api_base`` and ``api_key`` are assigned on the global ``openai``
    module, so they affect every other user of that module in the process.
    """

    def __init__(self, model: str):
        """Configure the global openai client and pick the model name.

        Args:
            model: Model name to request, used unless ``OPENAI_MODEL_NAME``
                is set in the environment.
        """
        super().__init__()
        api_base = os.getenv('OPENAI_API_BASE')
        # Only override when configured: assigning None here would wipe out
        # the package's default endpoint and break every subsequent request.
        if api_base:
            openai.api_base = api_base
        openai.api_key = os.getenv('OPENAI_API_KEY', 'EMPTY')
        self.model = os.getenv('OPENAI_MODEL_NAME', model)

    def _chat_stream(
        self,
        messages: List[Dict],
        stop: Optional[List[str]] = None,
    ) -> Iterator[str]:
        """Request a streamed completion and yield its text fragments.

        Args:
            messages: Chat messages forwarded unchanged to the API.
            stop: Optional stop sequences forwarded to the API.

        Yields:
            Text fragments of the assistant reply, in arrival order. Chunks
            with no ``choices`` or whose delta has no/``None`` ``content``
            are skipped, so only ``str`` values are produced.
        """
        response = openai.ChatCompletion.create(model=self.model,
                                                messages=messages,
                                                stop=stop,
                                                stream=True)
        # TODO: error handling
        for chunk in response:
            choices = getattr(chunk, 'choices', None)
            # Some OpenAI-compatible servers emit chunks without choices
            # (e.g. usage-only chunks); indexing [0] would raise IndexError.
            if not choices:
                continue
            # Role-only, function-call and final chunks can carry
            # content=None; never hand None to callers expecting str.
            content = getattr(choices[0].delta, 'content', None)
            if content is not None:
                yield content

    def _chat_no_stream(
        self,
        messages: List[Dict],
        stop: Optional[List[str]] = None,
    ) -> str:
        """Request a single, non-streamed completion.

        Args:
            messages: Chat messages forwarded unchanged to the API.
            stop: Optional stop sequences forwarded to the API.

        Returns:
            The ``content`` of the first choice's message.
        """
        response = openai.ChatCompletion.create(model=self.model,
                                                messages=messages,
                                                stop=stop,
                                                stream=False)
        # TODO: error handling
        return response.choices[0].message.content

    def chat_with_functions(self,
                            messages: List[Dict],
                            functions: Optional[List[Dict]] = None) -> Dict:
        """Request a completion, optionally offering function definitions.

        Args:
            messages: Chat messages forwarded unchanged to the API.
            functions: Function specs to offer the model. ``None`` or an
                empty list means no ``functions`` argument is sent at all.

        Returns:
            The first choice's ``message`` object as returned by the client.
        """
        request_kwargs = {'model': self.model, 'messages': messages}
        # Only send ``functions`` when there is at least one definition.
        if functions:
            request_kwargs['functions'] = functions
        response = openai.ChatCompletion.create(**request_kwargs)
        # TODO: error handling
        return response.choices[0].message