# FastGPT-backed chainlit data layer.
# (Hugging Face Spaces page banner residue — "Spaces / Runtime error" — removed;
# it was scrape artifact text, not part of the program.)
import json
import os
import uuid
from typing import Optional, List, Dict

import requests
from chainlit import PersistedUser
from chainlit.data import BaseDataLayer
from chainlit.types import PageInfo, ThreadFilter, ThreadDict, Pagination, PaginatedResponse
from literalai.helper import utc_now

# FastGPT share-link ("outLink") connection settings, read from the environment.
fastgpt_base_url = os.getenv("FASTGPT_BASE_URL")
share_id = os.getenv("FASTGPT_SHARE_ID")

# Captured once at import time; used as the createdAt of every PersistedUser.
now = utc_now()

# Per-process state: threads created in this session that FastGPT may not have
# persisted yet, and a thread_id -> user_id lookup used by get_thread_author().
user_cur_threads: List[Dict] = []
thread_user_dict: Dict[str, str] = {}
def change_type(user_type: str):
    """Map a FastGPT message role ('AI'/'Human') to a chainlit step type.

    Returns None for any other role, mirroring the original fall-through.
    """
    role_to_step = {
        'AI': 'assistant_message',
        'Human': 'user_message',
    }
    return role_to_step.get(user_type)
def get_app_info():
    """Fetch the FastGPT app metadata via the share-link init endpoint.

    Returns the app dict with an extra 'id' key (the FastGPT appId), or an
    empty dict when the request fails or the payload carries no data.
    """
    url = (f"{fastgpt_base_url}/api/core/chat/outLink/init"
           f"?chatId=&shareId={share_id}&outLinkUid=123456")
    app = {}
    # timeout so a stuck FastGPT server cannot hang module import forever
    with requests.get(url, timeout=10) as resp:
        if resp.status_code == 200:
            # Guard the payload: the old `res.get('data').get('app')` raised
            # AttributeError whenever a 200 response carried no 'data'.
            data = resp.json().get('data') or {}
            app = data.get('app') or {}
            app['id'] = data.get('appId')
    return app
# Fetch the app metadata once at import time and expose the pieces used below.
app_info = get_app_info()
app_id = app_info.get('id')
app_name = app_info.get('name')
# BUG FIX: when get_app_info() returned {} (request failed), 'chatConfig' is
# missing and `.get('chatConfig').get(...)` raised AttributeError on None.
welcome_text = (app_info.get('chatConfig') or {}).get('welcomeText')
def getHistories(user_id):
    """Return the chat-history threads of *user_id* from FastGPT.

    Also prepends the in-memory "current" thread when FastGPT does not know
    it yet, and refreshes the thread_id -> user_id lookup table used by
    get_thread_author(). Returns [] for a falsy user_id.
    """
    histories = []
    if user_id:
        payload = {"shareId": share_id, "outLinkUid": user_id}
        # NOTE(review): `data=` sends a form-encoded body; confirm the FastGPT
        # getHistories endpoint does not expect JSON (`json=payload`).
        with requests.post(f"{fastgpt_base_url}/api/core/chat/getHistories", data=payload, timeout=10) as resp:
            if resp.status_code == 200:
                data = resp.json()["data"]
                histories = [
                    {
                        "id": item["chatId"],
                        "name": item["title"],
                        "createdAt": item["updateTime"],
                        "userId": user_id,
                        "userIdentifier": user_id,
                    }
                    for item in data
                ]
    # Prepend this session's current thread unless FastGPT already returned it.
    thread = next((t for t in user_cur_threads if t["userId"] == user_id), None)
    if thread is not None:
        thread_id = thread.get("id")
        if not any(t.get("id") == thread_id for t in histories):
            histories.insert(0, thread)
    # Keep the thread -> owner mapping current for get_thread_author().
    for item in histories:
        thread_user_dict[item.get('id')] = item.get('userId')
    return histories
class FastgptDataLayer(BaseDataLayer):
    """chainlit data layer backed by the FastGPT share-link ("outLink") API.

    The chainlit user identifier doubles as the FastGPT ``outLinkUid`` and
    threads map to FastGPT chats. Module-level state tracks threads that only
    exist in this process (``user_cur_threads``) and the thread-id -> owner
    lookup (``thread_user_dict``).
    """

    async def get_user(self, identifier: str):
        """Return a PersistedUser for *identifier* (nothing is actually stored)."""
        print('get_user', identifier)
        # `now` is captured at import time, so every user shares one createdAt.
        return PersistedUser(id=identifier, createdAt=now, identifier=identifier)

    async def update_thread(self, thread_id: str, name: Optional[str] = None, user_id: Optional[str] = None, metadata: Optional[Dict] = None, tags: Optional[List[str]] = None,):
        """Create or update the in-memory "current" thread of *user_id*.

        FastGPT persists the chat itself when messages are sent; this only
        maintains the local bookkeeping entry.
        """
        print('---------update_thread----------', thread_id)
        thread = next((t for t in user_cur_threads if t["userId"] == user_id), None)
        if thread:
            # Only overwrite the fields that were actually supplied.
            if thread_id:
                thread["id"] = thread_id
            if name:
                thread["name"] = name
            if user_id:
                thread["userId"] = user_id
                thread["userIdentifier"] = user_id
            if metadata:
                thread["metadata"] = metadata
            if tags:
                thread["tags"] = tags
            thread["createdAt"] = utc_now()
        else:
            print('---------update_thread----------thread_id ', thread_id, name)
            user_cur_threads.append({"id": thread_id, "name": name, "metadata": metadata, "tags": tags, "createdAt": utc_now(), "userId": user_id, "userIdentifier": user_id,})

    async def get_thread_author(self, thread_id: str):
        """Return the owning user id (outLinkUid) of *thread_id*, or None."""
        print('get_thread_author')
        return thread_user_dict.get(thread_id, None)

    async def list_threads(self, pagination: Pagination, filters: ThreadFilter) -> PaginatedResponse[ThreadDict]:
        """Return the user's threads filtered by name, cursor-paginated."""
        threads = []
        if filters:
            threads = getHistories(filters.userId)
        search = filters.search if filters.search else ""
        filtered_threads = [thread for thread in threads if search in thread.get('name', '')]
        start = 0
        if pagination.cursor:
            # The cursor is the id of the last item of the previous page.
            for i, thread in enumerate(filtered_threads):
                if thread["id"] == pagination.cursor:
                    start = i + 1
                    break
        end = start + pagination.first
        paginated_threads = filtered_threads[start:end]
        # BUG FIX: there is a next page when the *total* filtered count exceeds
        # `end`; the old `len(paginated_threads) > end` could never be true
        # because a page holds at most `pagination.first` items.
        has_next_page = len(filtered_threads) > end
        start_cursor = paginated_threads[0]["id"] if paginated_threads else None
        end_cursor = paginated_threads[-1]["id"] if paginated_threads else None
        return PaginatedResponse(pageInfo=PageInfo(hasNextPage=has_next_page, startCursor=start_cursor, endCursor=end_cursor,), data=paginated_threads,)

    async def get_thread(self, thread_id: str):
        """Rebuild a chainlit thread dict from the FastGPT chat history.

        Returns None when the owner is unknown or the request yields no data.
        """
        print('get_thread', thread_id)
        user_id = thread_user_dict.get(thread_id, None)
        thread = None
        if user_id:
            params = {'chatId': thread_id, 'shareId': share_id, 'outLinkUid': user_id,}
            # timeout so a stuck FastGPT server cannot hang the UI forever
            with requests.get(f"{fastgpt_base_url}/api/core/chat/outLink/init", params=params, timeout=10) as resp:
                if resp.status_code == 200:
                    data = resp.json()["data"]
                    if data:
                        files = []
                        texts = []
                        for item in data['history']:
                            for entry in item['value']:
                                if entry.get('type') == 'text':
                                    texts.append({"id": item["_id"], "threadId": thread_id, "name": item["obj"], "type": change_type(item["obj"]), "input": None, "createdAt": utc_now(), "output": entry.get('text').get('content'),})
                                if entry.get('type') == 'file':
                                    # BUG FIX: str(uuid.UUID) stringified the *class*
                                    # ("<class 'uuid.UUID'>"), so every file element
                                    # shared the same bogus id; generate a real one.
                                    files.append({"id": str(uuid.uuid4()), "threadId": thread_id, "forId": item["_id"], "name": entry.get('file').get('name'), "type": entry.get('file').get('type'), "url": entry.get('file').get('url'), "display": "inline", "size": "medium"})
                        # BUG FIX: attribute the thread to its actual owner instead
                        # of the hard-coded "admin" — get_thread_author() and
                        # list_threads() key on the real user id.
                        thread = {"id": thread_id, "name": data.get("title", ''), "createdAt": utc_now(), "userId": user_id, "userIdentifier": user_id, "metadata": {"appId": data["appId"]}, "steps": texts, "elements": files,}
        return thread

    async def delete_thread(self, thread_id: str):
        """Drop the local record of *thread_id* and delete the chat on FastGPT."""
        print('delete_thread')
        thread = next((t for t in user_cur_threads if t["id"] == thread_id), None)
        user_id = thread_user_dict.get(thread_id, None)
        if thread:
            user_cur_threads.remove(thread)
        if user_id:
            params = {'appId': app_id, 'chatId': thread_id, 'shareId': share_id, 'outLinkUid': user_id,}
            # NOTE(review): issued as GET per the original code — confirm the
            # FastGPT delHistory endpoint does not require the DELETE method.
            requests.get(f"{fastgpt_base_url}/api/core/chat/delHistory", params=params, timeout=10)