# chatpregdrug / fastgpt_data.py
# Author: leonsimon23 — commit ab157ca ("Create fastgpt_data.py", verified)
import json
import os
import uuid
from typing import Optional, List, Dict
import requests
from chainlit import PersistedUser
from chainlit.data import BaseDataLayer
from chainlit.types import PageInfo, ThreadFilter, ThreadDict, Pagination, PaginatedResponse
from literalai.helper import utc_now
# FastGPT connection settings, read from the environment once at import time.
fastgpt_base_url = os.environ.get("FASTGPT_BASE_URL")
share_id = os.environ.get("FASTGPT_SHARE_ID")

# Timestamp taken once at import; FastgptDataLayer.get_user reuses it as every
# user's createdAt.
now = utc_now()

# Process-wide, in-memory state shared by every session:
# thread_user_dict maps a thread id to the user id that owns it, and
# user_cur_threads holds the locally tracked "current" thread dict per user
# (see FastgptDataLayer.update_thread).
thread_user_dict = {}
user_cur_threads = []
def change_type(user_type: str):
    """Map a FastGPT history ``obj`` value to the step ``type`` used by get_thread.

    'AI' becomes 'assistant_message' and 'Human' becomes 'user_message';
    any other value yields None.
    """
    step_type_by_author = {
        'AI': 'assistant_message',
        'Human': 'user_message',
    }
    return step_type_by_author.get(user_type)
def get_app_info(timeout: float = 10):
    """Fetch the FastGPT app bound to the configured share link.

    Calls ``GET {fastgpt_base_url}/api/core/chat/outLink/init`` with an empty
    ``chatId``, the module-level ``share_id`` and a fixed ``outLinkUid`` of
    ``"123456"``.

    Args:
        timeout: Seconds to wait for FastGPT. Without a timeout ``requests``
            can block forever, and this function is called at import time.

    Returns:
        The response's ``data.app`` dict with ``data.appId`` stored under
        ``'id'``, or ``{}`` when the status is not 200 or the payload carries
        no ``data``.
    """
    # Let requests build and URL-encode the query string instead of f-string interpolation.
    params = {"chatId": "", "shareId": share_id, "outLinkUid": "123456"}
    with requests.get(
        f"{fastgpt_base_url}/api/core/chat/outLink/init",
        params=params,
        timeout=timeout,
    ) as resp:
        if resp.status_code != 200:
            return {}
        data = json.loads(resp.content).get("data")
    if not data:
        # A 200 reply without data previously crashed with AttributeError.
        return {}
    app = data.get("app") or {}
    app["id"] = data.get("appId")
    return app
# Resolve the FastGPT app once at import time. app_id is sent along with
# delete requests by FastgptDataLayer.delete_thread.
app_info = get_app_info() or {}
app_id = app_info.get('id')
app_name = app_info.get('name')
# get_app_info() returns {} on a non-200 reply, and chatConfig may also be
# absent; fall back to None instead of failing the import with AttributeError.
welcome_text = (app_info.get('chatConfig') or {}).get('welcomeText')
def getHistories(user_id, timeout: float = 10):
    """Return *user_id*'s threads, merged with its in-memory current thread.

    Fetches the user's chat list from FastGPT's ``getHistories`` endpoint. The
    list is scoped by the module-level ``share_id`` with *user_id* as
    ``outLinkUid``. Each item is mapped to a thread dict. If ``user_cur_threads``
    holds a thread for this user whose id FastGPT did not return, that thread is
    prepended. Every returned thread's owner is recorded in ``thread_user_dict``.

    Args:
        user_id: FastGPT ``outLinkUid`` / Chainlit user identifier. A falsy
            value returns ``[]`` without any request.
        timeout: Seconds to wait for FastGPT.

    Returns:
        A list of thread dicts (possibly empty). A non-200 reply or an HTTP
        error yields only the locally tracked thread, if any.
    """
    if not user_id:
        return []

    histories = []
    try:
        with requests.post(
            f"{fastgpt_base_url}/api/core/chat/getHistories",
            data={"shareId": share_id, "outLinkUid": user_id},
            timeout=timeout,
        ) as resp:
            if resp.status_code == 200:
                # "data" may be null; treat that as "no histories" rather than crash.
                data = json.loads(resp.content).get("data") or []
                print(data)
                histories = [
                    {
                        "id": item["chatId"],
                        "name": item["title"],
                        "createdAt": item["updateTime"],
                        "userId": user_id,
                        "userIdentifier": user_id,
                    }
                    for item in data
                ]
    except requests.RequestException as err:
        # Handle an unreachable FastGPT like a non-200 reply: keep local state only.
        print('getHistories request failed', err)

    # Prepend the user's locally tracked thread unless FastGPT already listed it.
    # any() over an empty list is False, so this also covers "no remote histories".
    thread = next((t for t in user_cur_threads if t["userId"] == user_id), None)
    if thread and not any(t.get("id") == thread.get("id") for t in histories):
        histories.insert(0, thread)

    for item in histories:
        thread_user_dict[item.get('id')] = item.get('userId')
    return histories
class FastgptDataLayer(BaseDataLayer):
    """Chainlit data layer backed by a FastGPT share link.

    Thread history lives in FastGPT and is read through its out-link APIs,
    using the module-level ``share_id`` and the Chainlit user identifier as
    ``outLinkUid``. Two module-level structures fill the gaps:
    ``user_cur_threads`` tracks each user's current thread, and
    ``thread_user_dict`` maps thread ids to their owners.

    NOTE(review): the HTTP calls use blocking ``requests`` inside ``async``
    methods, so the event loop is stalled while waiting on FastGPT.
    """

    # Seconds to wait for any FastGPT HTTP call made by this class.
    request_timeout = 10

    async def get_user(self, identifier: str):
        """Return a user whose id and identifier are both *identifier*.

        No lookup is performed. ``createdAt`` is the module-level ``now``.
        """
        print('get_user', identifier)
        return PersistedUser(id=identifier, createdAt=now, identifier=identifier)

    async def update_thread(self, thread_id: str, name: Optional[str] = None, user_id: Optional[str] = None, metadata: Optional[Dict] = None, tags: Optional[List[str]] = None,):
        """Create or update the in-memory current thread.

        With a *user_id*, that user's entry in ``user_cur_threads`` is updated
        (re-pointed to *thread_id*) or created. Without one (e.g. a call that
        only supplies *name*), the entry whose id is *thread_id* is updated if
        present. Otherwise nothing is recorded, because an ownerless entry
        would only ever match a lookup for user ``None``.
        """
        print('---------update_thread----------', thread_id)
        if user_id:
            thread = next((t for t in user_cur_threads if t["userId"] == user_id), None)
        else:
            thread = next((t for t in user_cur_threads if t["id"] == thread_id), None)

        if thread:
            if thread_id:
                thread["id"] = thread_id
            if name:
                thread["name"] = name
            if user_id:
                thread["userId"] = user_id
                thread["userIdentifier"] = user_id
            if metadata:
                thread["metadata"] = metadata
            if tags:
                thread["tags"] = tags
            thread["createdAt"] = utc_now()
        elif user_id:
            print('---------update_thread----------thread_id ', thread_id, name)
            user_cur_threads.append({"id": thread_id, "name": name, "metadata": metadata, "tags": tags, "createdAt": utc_now(), "userId": user_id, "userIdentifier": user_id,})

    async def get_thread_author(self, thread_id: str):
        """Return the owner recorded in ``thread_user_dict`` for *thread_id*, or None."""
        print('get_thread_author')
        return thread_user_dict.get(thread_id, None)

    async def list_threads(self, pagination: Pagination, filters: ThreadFilter) -> PaginatedResponse[ThreadDict]:
        """Return one page of the filtered user's threads.

        Threads come from ``getHistories(filters.userId)``. They are kept when
        ``filters.search`` is a substring of their name (case-sensitive). They
        are then paged by cursor: the page starts right after the thread whose
        id equals ``pagination.cursor`` (or at the beginning if the cursor is
        unset or unknown) and holds at most ``pagination.first`` threads.
        """
        threads = []
        search = ""
        if filters:
            threads = getHistories(filters.userId)
            search = filters.search or ""
        # "name" can be None for a thread created by update_thread without a name.
        filtered_threads = [t for t in threads if search in (t.get('name') or '')]

        start = 0
        if pagination.cursor:
            start = next(
                (i + 1 for i, t in enumerate(filtered_threads) if t["id"] == pagination.cursor),
                0,
            )
        end = start + pagination.first
        paginated_threads = filtered_threads[start:end]
        # More pages exist exactly when threads remain beyond this slice.
        has_next_page = end < len(filtered_threads)
        start_cursor = paginated_threads[0]["id"] if paginated_threads else None
        end_cursor = paginated_threads[-1]["id"] if paginated_threads else None
        return PaginatedResponse(pageInfo=PageInfo(hasNextPage=has_next_page, startCursor=start_cursor, endCursor=end_cursor,), data=paginated_threads,)

    async def get_thread(self, thread_id: str):
        """Load *thread_id*'s messages and files from FastGPT.

        Returns None when the thread's owner is unknown (not in
        ``thread_user_dict``), when FastGPT answers with a non-200 status, or
        when the reply has no ``data``. Otherwise returns a thread dict: text
        entries in the history become ``steps``, and file entries become
        ``elements`` attached to their message through ``forId``.
        """
        print('get_thread', thread_id)
        user_id = thread_user_dict.get(thread_id, None)
        if not user_id:
            return None

        params = {'chatId': thread_id, 'shareId': share_id, 'outLinkUid': user_id,}
        with requests.get(f"{fastgpt_base_url}/api/core/chat/outLink/init", params=params, timeout=self.request_timeout,) as resp:
            if resp.status_code != 200:
                return None
            data = json.loads(resp.content).get("data")
        if not data:
            return None

        texts = []
        files = []
        for item in data.get('history') or []:
            for entry in item.get('value') or []:
                entry_type = entry.get('type')
                if entry_type == 'text':
                    texts.append({"id": item["_id"], "threadId": thread_id, "name": item["obj"], "type": change_type(item["obj"]), "input": None, "createdAt": utc_now(), "output": (entry.get('text') or {}).get('content'),})
                elif entry_type == 'file':
                    file_info = entry.get('file') or {}
                    # uuid4() gives each element a unique id; str(uuid.UUID) was the same class name for all.
                    files.append({"id": str(uuid.uuid4()), "threadId": thread_id, "forId": item["_id"], "name": file_info.get('name'), "type": file_info.get('type'), "url": file_info.get('url'), "display": "inline", "size": "medium"})

        # Report the same owner that get_thread_author returns for this thread.
        return {"id": thread_id, "name": data.get("title", ''), "createdAt": utc_now(), "userId": user_id, "userIdentifier": user_id, "metadata": {"appId": data.get("appId")}, "steps": texts, "elements": files,}

    async def delete_thread(self, thread_id: str):
        """Drop *thread_id* from local state and ask FastGPT to delete its history.

        The FastGPT request is only sent when the thread's owner is known from
        ``thread_user_dict``. The owner mapping is removed so the deleted
        thread no longer reports an author.
        """
        print('delete_thread')
        thread = next((t for t in user_cur_threads if t["id"] == thread_id), None)
        if thread:
            user_cur_threads.remove(thread)
        user_id = thread_user_dict.pop(thread_id, None)
        if user_id:
            params = {'appId': app_id, 'chatId': thread_id, 'shareId': share_id, 'outLinkUid': user_id,}
            requests.get(f"{fastgpt_base_url}/api/core/chat/delHistory", params=params, timeout=self.request_timeout)