Lohia, Aditya committed on
Commit 121f84a · 1 Parent(s): 2e70021

add: gateway to interact with the api

Files changed (1)
  1. gateway.py +94 -0
gateway.py ADDED
@@ -0,0 +1,94 @@
+ import json
+ import requests
+
+
+ def check_server_health(cloud_gateway_api: str):
+     """
+     Use the /health API endpoint to check the server health.
+
+     Args:
+         cloud_gateway_api: API endpoint to probe.
+
+     Returns:
+         True if the server is active, False otherwise.
+     """
+     try:
+         response = requests.get(cloud_gateway_api + "/health")
+         if response.status_code == 200:
+             return True
+     except requests.ConnectionError:
+         print("Failed to establish connection to the server.")
+
+     return False
+
+
+ def request_generation(
+     message: str,
+     system_prompt: str,
+     cloud_gateway_api: str,
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ):
+     """
+     Request streaming generation from the cloud gateway API. Uses the requests module with stream=True to consume
+     token-by-token output from the LLM.
+
+     Args:
+         message: prompt from the user.
+         system_prompt: system prompt to prepend to the conversation.
+         cloud_gateway_api (str): API endpoint to send the request to.
+         max_new_tokens: maximum number of tokens to generate, ignoring the number of tokens in the prompt.
+         temperature: the value used to modulate the next-token probabilities.
+         top_p: if set to a float < 1, only the smallest set of most probable tokens with probabilities that add up
+             to top_p or higher are kept for generation.
+         top_k: the number of highest-probability vocabulary tokens to keep for top-k filtering.
+         repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty.
+
+     Yields:
+         Generated text fragments as they are streamed back from the server.
+     """
+
+     payload = {
+         "model": "google/gemma-3-27b-it",
+         "messages": [
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": message},
+         ],
+         "max_tokens": max_new_tokens,
+         "temperature": temperature,
+         "top_p": top_p,
+         "repetition_penalty": repetition_penalty,
+         "top_k": top_k,
+         "stream": True,  # Enable streaming
+     }
+
+     with requests.post(
+         cloud_gateway_api + "/v1/chat/completions", json=payload, stream=True
+     ) as response:
+         for chunk in response.iter_lines():
+             if chunk:
+                 # Decode the chunk from bytes to a string
+                 chunk_str = chunk.decode("utf-8")
+
+                 # Remove the `data: ` prefix from the chunk if it exists
+                 if chunk_str.startswith("data: "):
+                     chunk_str = chunk_str[len("data: ") :]
+
+                 # Stop on the end-of-stream sentinel
+                 if chunk_str.strip() == "[DONE]":
+                     break
+
+                 # Parse the chunk into a JSON object
+                 try:
+                     chunk_json = json.loads(chunk_str)
+                     # Extract the "content" field from the choices
+                     content = chunk_json["choices"][0]["delta"].get("content", "")
+
+                     # Yield the generated content as it is streamed
+                     if content:
+                         yield content
+                 except json.JSONDecodeError:
+                     # Skip chunks that cannot be decoded as JSON
+                     continue
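
For context, each streamed line from an OpenAI-compatible endpoint looks like `data: {"choices": [{"delta": {"content": "..."}}]}`, followed by a final `data: [DONE]` sentinel, which is why the loop strips the prefix and breaks on [DONE]. Below is a minimal usage sketch of the two helpers added in this commit; the gateway URL is a placeholder assumption, not part of the commit:

# Usage sketch for gateway.py (GATEWAY_URL is a hypothetical endpoint).
from gateway import check_server_health, request_generation

GATEWAY_URL = "http://localhost:8000"  # assumption: adjust to your deployment

if check_server_health(GATEWAY_URL):
    # Print text fragments as they arrive from the streaming endpoint.
    for fragment in request_generation(
        message="Summarize streaming inference in one sentence.",
        system_prompt="You are a concise assistant.",
        cloud_gateway_api=GATEWAY_URL,
        max_new_tokens=128,
    ):
        print(fragment, end="", flush=True)
else:
    print("Gateway is not reachable.")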