tastypear commited on
Commit
36c4cac
·
verified ·
1 Parent(s): 6f79ee1

Update cerebras.py

Browse files
Files changed (1) hide show
  1. cerebras.py +132 -59
cerebras.py CHANGED
@@ -1,59 +1,132 @@
1
- # coding:utf-8
2
- import requests
3
- import json
4
- from datetime import datetime, timedelta, timezone
5
-
6
-
7
- class CerebrasUnofficial:
8
- def __init__(self, authjs_session_token: str):
9
- self.api_url = 'https://api.cerebras.ai'
10
- self.authjs_session_token = authjs_session_token
11
- self.key = None
12
- self.expiry = None
13
- self.session = requests.Session()
14
- self.session.headers.update({
15
- 'Content-Type': 'application/json',
16
- 'Cookie': f'authjs.session-token={self.authjs_session_token}',
17
- })
18
-
19
-
20
- def _get_key_from_graphql(self):
21
- json_data = {
22
- 'operationName': 'GetMyDemoApiKey',
23
- 'variables': {},
24
- 'query': 'query GetMyDemoApiKey {\n GetMyDemoApiKey\n}',
25
- }
26
- response = self.session.post(
27
- 'https://inference.cerebras.ai/api/graphql', json=json_data
28
- )
29
- response.raise_for_status()
30
-
31
- data = response.json()
32
- try:
33
- if 'data' in data and 'GetMyDemoApiKey' in data['data']:
34
- self.key = data['data']['GetMyDemoApiKey']
35
- except Exception:
36
- raise Exception('Maybe your authjs.session-token is invalid.')
37
-
38
-
39
- def _get_expiry_from_session(self):
40
- response = self.session.get('https://cloud.cerebras.ai/api/auth/session')
41
- response.raise_for_status()
42
- data = response.json()
43
- if 'user' in data and 'demoApiKeyExpiry' in data['user']:
44
- self.expiry = datetime.fromisoformat(
45
- data['user']['demoApiKeyExpiry'].replace('Z', '+00:00')
46
- )
47
-
48
-
49
- def get_api_key(self) -> str:
50
- if self.key is None:
51
- self._get_key_from_graphql()
52
- else:
53
- if self.expiry is None:
54
- self._get_expiry_from_session()
55
- if datetime.now(timezone.utc) >= self.expiry:
56
- self._get_key_from_graphql()
57
- self.expiry = datetime.now(timezone.utc) + timedelta(minutes=10)
58
-
59
- return self.key
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding:utf-8
2
+ import argparse
3
+ import requests
4
+ import json
5
+ import os
6
+ from cerebras import CerebrasUnofficial
7
+ from flask import Flask, request, Response, stream_with_context, jsonify
8
+ import sys
9
+
10
+ # -- Start of Config --
11
+
12
+ # Replace with your cerebras.ai session token found in the `authjs.session-token` cookie.
13
+ # Or you can `set AUTHJS_SESSION_TOKEN=authjs.session-token`
14
+ # This token is valid for one month.
15
+ authjs_session_token = '12345678-abcd-abcd-abcd-12345678abcd'
16
+
17
+ # Replace with any string you wish like `my-api-key`.
18
+ # Or you can `set SERVER_API_KEY=my-api-key`
19
+ # You should set it to update the session token in the future.
20
+ server_api_key = 'my-api-key'
21
+
22
+ # -- End of Config --
23
+
24
+ sys.tracebacklimit = 0
25
+
26
+ authjs_session_token = os.environ.get('AUTHJS_SESSION_TOKEN', authjs_session_token)
27
+ server_api_key = os.environ.get('SERVER_API_KEY', server_api_key)
28
+ print(f'Using the cookie: authjs.session-token={authjs_session_token}')
29
+ print(f'Your api key: {server_api_key}')
30
+
31
+ cerebras_ai = CerebrasUnofficial(authjs_session_token)
32
+
33
+ app = Flask(__name__)
34
+ app.json.sort_keys = False
35
+ parser = argparse.ArgumentParser(description='Cerebras.AI API')
36
+ parser.add_argument('--host', type=str, help='Set the ip address.(default: 0.0.0.0)', default='0.0.0.0')
37
+ parser.add_argument('--port', type=int, help='Set the port.(default: 7860)', default=7860)
38
+ args = parser.parse_args()
39
+
40
+ class Provider:
41
+ key = ''
42
+ max_tokens = None
43
+ api_url = ''
44
+
45
+ def __init__(self, request_key, model):
46
+ self.request_key = request_key
47
+ self.model = model
48
+ self.init_request_info()
49
+
50
+ def init_request_info(self):
51
+ if self.request_key == server_api_key:
52
+ self.api_url = cerebras_ai.api_url
53
+ self.key = cerebras_ai.get_api_key()
54
+
55
+ @app.route('/api', methods=['GET', 'POST'])
56
+ @app.route('/', methods=['GET', 'POST'])
57
+ def index():
58
+ return f'''
59
+ renew/change token by visiting:<br>
60
+ {request.host_url}renew?key={{your server api key}}&token={{your Cerebras authjs_session_token}}<br>
61
+ <br>
62
+ Your interface:<br>
63
+ {request.host_url}v1/chat/completions OR<br>
64
+ {request.host_url}api/v1/chat/completions<br>
65
+ <br>
66
+ For more infomation by visiting:<br>
67
+ https://github.com/tastypear/CerebrasUnofficial
68
+ '''
69
+
70
+ @app.route('/api/renew', methods=['GET', 'POST'])
71
+ @app.route('/renew', methods=['GET', 'POST'])
72
+ def renew_token():
73
+ if server_api_key == request.args.get('key', ''):
74
+ request_token = request.args.get('token', '')
75
+ global cerebras_ai
76
+ cerebras_ai = CerebrasUnofficial(request_token)
77
+ return f'new authjs.session_token: {request_token}'
78
+ else:
79
+ raise Exception('invalid api key')
80
+
81
+ @app.route('/api/v1/models', methods=['GET', 'POST'])
82
+ @app.route('/v1/models', methods=['GET', 'POST'])
83
+ def model_list():
84
+ model_list = {
85
+ 'object': 'list',
86
+ 'data': [{
87
+ 'id': 'llama3.1-8b',
88
+ 'object': 'model',
89
+ 'created': 1721692800,
90
+ 'owned_by': 'Meta'
91
+ }, {
92
+ 'id': 'llama-3.3-70b',
93
+ 'object': 'model',
94
+ 'created': 1733443200,
95
+ 'owned_by': 'Meta'
96
+ }, {
97
+ 'id': 'deepseek-r1-distill-llama-70b',
98
+ 'object': 'model',
99
+ 'created': 1733443200,
100
+ 'owned_by': 'deepseek'
101
+ }]
102
+ }
103
+ return jsonify(model_list)
104
+
105
+
106
+ @app.route('/api/v1/chat/completions', methods=['POST'])
107
+ @app.route('/v1/chat/completions', methods=['POST'])
108
+ def proxy():
109
+ request_key = request.headers['Authorization'].split(' ')[1]
110
+ if server_api_key != request_key:
111
+ raise Exception('invalid api key')
112
+
113
+ headers = dict(request.headers)
114
+ headers.pop('Host', None)
115
+ headers.pop('Content-Length', None)
116
+
117
+ headers['X-Use-Cache'] = 'false'
118
+ model = request.get_json()['model']
119
+ provider = Provider(request_key, model)
120
+ headers['Authorization'] = f'Bearer {provider.key}'
121
+ chat_api = f'{provider.api_url}/v1/chat/completions'
122
+
123
+ def generate():
124
+ with requests.post(chat_api, json=request.json, headers=headers, stream=True) as resp:
125
+ for chunk in resp.iter_content(chunk_size=1024):
126
+ if chunk:
127
+ chunk_str = chunk.decode('utf-8')
128
+ yield chunk_str
129
+ return Response(stream_with_context(generate()), content_type='text/event-stream')
130
+
131
+ if __name__ == '__main__':
132
+ app.run(host=args.host, port=args.port, debug=True)