import os
from typing import Dict, Iterator, List, Optional

import openai

from agent.llm.base import BaseChatModel


class ChatAsOAI(BaseChatModel):
    """Chat model backed by an OpenAI-compatible endpoint (pre-1.0 openai SDK)."""

    def __init__(self, model: str):
        super().__init__()
        # Endpoint and credentials come from the environment; the key defaults
        # to 'EMPTY', which locally hosted OpenAI-compatible servers commonly
        # accept.
        openai.api_base = os.getenv('OPENAI_API_BASE')
        openai.api_key = os.getenv('OPENAI_API_KEY', 'EMPTY')
        # OPENAI_MODEL_NAME, if set, overrides the model passed in.
        self.model = os.getenv('OPENAI_MODEL_NAME', model)

    def _chat_stream(
        self,
        messages: List[Dict],
        stop: Optional[List[str]] = None,
    ) -> Iterator[str]:
        response = openai.ChatCompletion.create(
            model=self.model, messages=messages, stop=stop, stream=True)
        # TODO: error handling
        for chunk in response:
            # Skip chunks that carry no content delta (e.g. the initial
            # role-only chunk and the final empty chunk).
            delta = chunk.choices[0].delta
            if hasattr(delta, 'content') and delta.content:
                yield delta.content

    def _chat_no_stream(
        self,
        messages: List[Dict],
        stop: Optional[List[str]] = None,
    ) -> str:
        response = openai.ChatCompletion.create(
            model=self.model, messages=messages, stop=stop, stream=False)
        # TODO: error handling
        return response.choices[0].message.content

    def chat_with_functions(self,
                            messages: List[Dict],
                            functions: Optional[List[Dict]] = None) -> Dict:
        kwargs = {'model': self.model, 'messages': messages}
        if functions:
            # Only forward `functions` when provided; omit the parameter
            # entirely otherwise.
            kwargs['functions'] = functions
        response = openai.ChatCompletion.create(**kwargs)
        # TODO: error handling
        return response.choices[0].message
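

# Minimal usage sketch (an assumption, not part of the original module): it
# expects an OpenAI-compatible endpoint configured via OPENAI_API_BASE and
# OPENAI_API_KEY. In real use, BaseChatModel presumably exposes the public
# chat interface that wraps the private methods above; the private call here
# is just a smoke test.
if __name__ == '__main__':
    llm = ChatAsOAI(model='gpt-3.5-turbo')
    reply = llm._chat_no_stream(
        messages=[{'role': 'user', 'content': 'Say hello.'}])
    print(reply)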