test pr
Files changed:
- .gitignore +5 -1
- OpenAIChatAtomicFlow.py +9 -10
- OpenAIChatAtomicFlow.yaml +2 -0
.gitignore CHANGED

@@ -158,4 +158,8 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 .idea/
-.*cache*/
+.*cache*/
+
+
+# auto-generated by flows, all synced moduels will be ignored by default
+FLOW_MODULE_ID
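Of the new entries, `.*cache*/` sweeps up dot-directories with "cache" in their name (`.pytest_cache/`, `.mypy_cache/`, and the like), and `FLOW_MODULE_ID` is, per the in-file comment, what flows auto-generates for synced modules. A quick way to confirm what the updated rules actually ignore is `git check-ignore -v`, which prints the deciding pattern; a minimal sketch, where the sample paths are illustrative and not from this PR:

import subprocess

# Ask git which of these candidate paths the updated .gitignore rules match.
# Must be run inside the repository; the sample paths are illustrative only.
for path in [".pytest_cache/", ".mypy_cache/", "FLOW_MODULE_ID", "flow.py"]:
    result = subprocess.run(["git", "check-ignore", "-v", path],
                            capture_output=True, text=True)
    # Exit status 0 means the path is ignored; -v names the matching rule.
    verdict = "ignored" if result.returncode == 0 else "not ignored"
    print(f"{path}: {verdict} {result.stdout.strip()}")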
OpenAIChatAtomicFlow.py CHANGED

@@ -2,7 +2,7 @@ from copy import deepcopy
 
 import hydra
 
-import time
+import os, time
 
 from typing import List, Dict, Optional, Any
 
@@ -30,8 +30,6 @@ class OpenAIChatAtomicFlow(AtomicFlow):
 
     SUPPORTS_CACHING: bool = True
 
-    api_keys: Dict[str, str]
-
     system_message_prompt_template: PromptTemplate
     human_message_prompt_template: PromptTemplate
 
@@ -41,8 +39,8 @@ class OpenAIChatAtomicFlow(AtomicFlow):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-
-        self.
+        # TODO(yeeef): add documentation, what is guaranteed to be in flow_config?
+        # TODO(yeeef): everything set here as self.xxx=yyy will be invalidated after reset
 
         assert self.flow_config["name"] not in [
             "system",
@@ -165,8 +163,7 @@ class OpenAIChatAtomicFlow(AtomicFlow):
         )
         self._log_message(chat_message)
 
-    def _call(self):
-        api_key = self.api_keys["openai"]
+    def _call(self, api_key: str):
 
         backend = langchain.chat_models.ChatOpenAI(
             model_name=self.flow_config["model_name"],
@@ -233,14 +230,16 @@ class OpenAIChatAtomicFlow(AtomicFlow):
             input_data: Dict[str, Any],
             private_keys: Optional[List[str]] = [],
             keys_to_ignore_for_hash: Optional[List[str]] = []) -> Dict[str, Any]:
-
-
+
+        api_key = self.flow_config.get("api_key", "")
+        if "api_key" in input_data:
+            api_key = input_data.pop("api_key")
 
         # ~~~ Process input ~~~
         self._process_input(input_data)
 
         # ~~~ Call ~~~
-        response = self._call()
+        response = self._call(api_key)
         self._state_update_add_chat_message(
             role=self.flow_config["assistant_name"],
             content=response
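The net effect of the Python changes is that the API key no longer lives on the flow as an `api_keys` attribute (which, per the new TODO notes in `__init__`, would be invalidated after a reset anyway). Instead, the caller-facing method resolves a key, preferring an `api_key` supplied in `input_data` over the one in `flow_config`, and passes it explicitly to `_call(api_key)`. A minimal sketch of just that resolution step; the standalone helper is hypothetical, but the dict names and precedence mirror the diff:

from typing import Any, Dict

def resolve_api_key(flow_config: Dict[str, Any], input_data: Dict[str, Any]) -> str:
    # Start from the key in the YAML config (may still be the placeholder).
    api_key = flow_config.get("api_key", "")
    # A key supplied at call time wins; pop it so the secret is removed from
    # input_data before the flow processes or hashes it.
    if "api_key" in input_data:
        api_key = input_data.pop("api_key")
    return api_key

# Illustrative usage:
config = {"api_key": "YOUR_API_KEY"}
data = {"query": "hi", "api_key": "sk-real-key"}
assert resolve_api_key(config, data) == "sk-real-key"
assert "api_key" not in data  # popped, so input processing never sees it

Popping the key rather than just reading it looks deliberate: `input_data` feeds `_process_input` and, given `SUPPORTS_CACHING` and the `keys_to_ignore_for_hash` parameter, presumably the cache hash as well, so leaving a live secret in it would perturb cache keys and risk leaking it into logged state.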
OpenAIChatAtomicFlow.yaml CHANGED

@@ -11,6 +11,8 @@ generation_parameters:
   frequency_penalty: 0
   presence_penalty: 0
 
+api_key: "YOUR_API_KEY"
+
 n_api_retries: 6
 wait_time_between_retries: 20
 
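`YOUR_API_KEY` is a placeholder, and because an `api_key` in `input_data` overrides the config value, real keys never need to be committed to the YAML. A hedged caller-side sketch; the `OPENAI_API_KEY` variable name and the `flow.run(...)` call shape are assumptions, not part of the PR:

import os

# Keep the real key out of version control: read it from the environment
# (the variable name here is an assumption) and pass it per call.
input_data = {
    "query": "Summarize this PR.",
    "api_key": os.environ.get("OPENAI_API_KEY", ""),
}
# With the diff applied, the flow pops "api_key" from input_data and hands it
# to _call(), falling back to flow_config["api_key"] when no key is supplied:
# outputs = flow.run(input_data=input_data)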