Lohia, Aditya committed
Commit 121f84a · Parent(s): 2e70021

add: gateway to interact with the api

Files changed: gateway.py (+94, -0)
gateway.py ADDED
@@ -0,0 +1,94 @@
import json

import requests


def check_server_health(cloud_gateway_api: str):
    """
    Use the appropriate API endpoint to check the server health.

    Args:
        cloud_gateway_api: API endpoint to probe.

    Returns:
        True if the server is active, False otherwise.
    """
    try:
        # Time out quickly rather than hanging if the gateway is unreachable.
        response = requests.get(cloud_gateway_api + "/health", timeout=5)
        if response.status_code == 200:
            return True
    except (requests.ConnectionError, requests.Timeout):
        print("Failed to establish connection to the server.")

    return False
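
A quick sanity check before wiring the gateway into an app might look like the sketch below; the URL is a placeholder for illustration, not a real deployment:

# Hypothetical gateway URL, for illustration only.
GATEWAY_URL = "http://localhost:8000"

if check_server_health(GATEWAY_URL):
    print("Gateway is reachable.")
else:
    print("Gateway is down; check the deployment.")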

def request_generation(
    message: str,
    system_prompt: str,
    cloud_gateway_api: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
):
    """
    Request streaming generation from the cloud gateway API. Uses the requests
    module with stream=True to consume token-by-token output from the LLM.

    Args:
        message: prompt from the user.
        system_prompt: system prompt to prepend to the conversation.
        cloud_gateway_api: API endpoint to send the request to.
        max_new_tokens: maximum number of tokens to generate, ignoring the
            number of tokens in the prompt.
        temperature: the value used to modulate the next-token probabilities.
        top_p: if set to a float < 1, only the smallest set of most probable
            tokens whose probabilities add up to top_p or higher is kept for
            generation.
        top_k: the number of highest-probability vocabulary tokens to keep
            for top-k filtering.
        repetition_penalty: the parameter for repetition penalty. 1.0 means
            no penalty.

    Yields:
        Fragments of generated text as they are streamed from the server.
    """

    payload = {
        "model": "google/gemma-3-27b-it",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "stream": True,  # Enable streaming
    }

    with requests.post(
        cloud_gateway_api + "/v1/chat/completions", json=payload, stream=True
    ) as response:
        for chunk in response.iter_lines():
            if chunk:
                # Convert the chunk from bytes to a string
                chunk_str = chunk.decode("utf-8")

                # Remove the `data: ` prefix from the chunk if it exists
                if chunk_str.startswith("data: "):
                    chunk_str = chunk_str[len("data: "):]

                # Stop when the server signals the end of the stream
                if chunk_str.strip() == "[DONE]":
                    break

                # Parse the chunk into a JSON object
                try:
                    chunk_json = json.loads(chunk_str)
                    # Extract the "content" field from the choices
                    content = chunk_json["choices"][0]["delta"].get("content", "")

                    # Yield the generated content as it's streamed
                    if content:
                        yield content
                except json.JSONDecodeError:
                    # Skip chunks that fail to decode
                    continue
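
A minimal driver sketch showing how the two functions could be combined; the gateway URL and prompts below are placeholder values, not part of the committed file:

# Placeholder values for illustration; swap in the real gateway URL.
GATEWAY_URL = "http://localhost:8000"

if check_server_health(GATEWAY_URL):
    for fragment in request_generation(
        message="Summarize what an API gateway does.",
        system_prompt="You are a concise, helpful assistant.",
        cloud_gateway_api=GATEWAY_URL,
    ):
        # Print fragments as they arrive, without buffering a full line.
        print(fragment, end="", flush=True)
    print()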