leagend commited on
Commit
02142b9
·
verified ·
1 Parent(s): eef5580

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +136 -0
  2. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ from copy import deepcopy
5
+ from dataclasses import asdict
6
+ from typing import Dict, List, Union
7
+
8
+ import janus
9
+ from fastapi import FastAPI
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from lagent.schema import AgentStatusCode
12
+ from pydantic import BaseModel
13
+ from sse_starlette.sse import EventSourceResponse
14
+
15
+ from mindsearch.agent import init_agent
16
+
17
+
18
+ def parse_arguments():
19
+ import argparse
20
+ parser = argparse.ArgumentParser(description='MindSearch API')
21
+ parser.add_argument('--lang', default='cn', type=str, help='Language')
22
+ parser.add_argument('--model_format',
23
+ default='internlm_server',
24
+ type=str,
25
+ help='Model format')
26
+ parser.add_argument('--search_engine',
27
+ default='DuckDuckGoSearch',
28
+ type=str,
29
+ help='Search engine')
30
+ return parser.parse_args()
31
+
32
+
33
+ args = parse_arguments()
34
+ app = FastAPI(docs_url='/')
35
+
36
+ app.add_middleware(CORSMiddleware,
37
+ allow_origins=['*'],
38
+ allow_credentials=True,
39
+ allow_methods=['*'],
40
+ allow_headers=['*'])
41
+
42
+
43
+ class GenerationParams(BaseModel):
44
+ inputs: Union[str, List[Dict]]
45
+ agent_cfg: Dict = dict()
46
+
47
+
48
+ @app.post('/solve')
49
+ async def run(request: GenerationParams):
50
+
51
+ def convert_adjacency_to_tree(adjacency_input, root_name):
52
+
53
+ def build_tree(node_name):
54
+ node = {'name': node_name, 'children': []}
55
+ if node_name in adjacency_input:
56
+ for child in adjacency_input[node_name]:
57
+ child_node = build_tree(child['name'])
58
+ child_node['state'] = child['state']
59
+ child_node['id'] = child['id']
60
+ node['children'].append(child_node)
61
+ return node
62
+
63
+ return build_tree(root_name)
64
+
65
+ async def generate():
66
+ try:
67
+ queue = janus.Queue()
68
+ stop_event = asyncio.Event()
69
+
70
+ # Wrapping a sync generator as an async generator using run_in_executor
71
+ def sync_generator_wrapper():
72
+ try:
73
+ for response in agent.stream_chat(inputs):
74
+ queue.sync_q.put(response)
75
+ except Exception as e:
76
+ logging.exception(
77
+ f'Exception in sync_generator_wrapper: {e}')
78
+ finally:
79
+ # Notify async_generator_wrapper that the data generation is complete.
80
+ queue.sync_q.put(None)
81
+
82
+ async def async_generator_wrapper():
83
+ loop = asyncio.get_event_loop()
84
+ loop.run_in_executor(None, sync_generator_wrapper)
85
+ while True:
86
+ response = await queue.async_q.get()
87
+ if response is None: # Ensure that all elements are consumed
88
+ break
89
+ yield response
90
+ if not isinstance(
91
+ response,
92
+ tuple) and response.state == AgentStatusCode.END:
93
+ break
94
+ stop_event.set() # Inform sync_generator_wrapper to stop
95
+
96
+ async for response in async_generator_wrapper():
97
+ if isinstance(response, tuple):
98
+ agent_return, node_name = response
99
+ else:
100
+ agent_return = response
101
+ node_name = None
102
+ origin_adj = deepcopy(agent_return.adjacency_list)
103
+ adjacency_list = convert_adjacency_to_tree(
104
+ agent_return.adjacency_list, 'root')
105
+ assert adjacency_list[
106
+ 'name'] == 'root' and 'children' in adjacency_list
107
+ agent_return.adjacency_list = adjacency_list['children']
108
+ agent_return = asdict(agent_return)
109
+ agent_return['adj'] = origin_adj
110
+ response_json = json.dumps(dict(response=agent_return,
111
+ current_node=node_name),
112
+ ensure_ascii=False)
113
+ yield {'data': response_json}
114
+ # yield f'data: {response_json}\n\n'
115
+ except Exception as exc:
116
+ msg = 'An error occurred while generating the response.'
117
+ logging.exception(msg)
118
+ response_json = json.dumps(
119
+ dict(error=dict(msg=msg, details=str(exc))),
120
+ ensure_ascii=False)
121
+ yield {'data': response_json}
122
+ # yield f'data: {response_json}\n\n'
123
+ finally:
124
+ await stop_event.wait(
125
+ ) # Waiting for async_generator_wrapper to stop
126
+ queue.close()
127
+ await queue.wait_closed()
128
+
129
+ inputs = request.inputs
130
+ agent = init_agent(lang=args.lang, model_format=args.model_format,search_engine=args.search_engine)
131
+ return EventSourceResponse(generate())
132
+
133
+
134
+ if __name__ == '__main__':
135
+ import uvicorn
136
+ uvicorn.run(app, host='0.0.0.0', port=8002, log_level='info')
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ duckduckgo_search==5.3.1b1
2
+ einops
3
+ fastapi
4
+ git+https://github.com/InternLM/lagent.git
5
+ gradio
6
+ janus
7
+ lmdeploy
8
+ pyvis
9
+ sse-starlette
10
+ termcolor
11
+ transformers==4.41.0
12
+ uvicorn