awacke1 commited on
Commit
9db9a61
·
1 Parent(s): 57d6629

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -31
app.py CHANGED
@@ -95,37 +95,43 @@ def add_witty_humor_buttons():
95
 
96
  # Function to Stream Inference Client for Inference Endpoint Responses
97
  def StreamLLMChatResponse(prompt):
98
- endpoint_url = API_URL
99
- hf_token = API_KEY
100
- client = InferenceClient(endpoint_url, token=hf_token)
101
- gen_kwargs = dict(
102
- max_new_tokens=512,
103
- top_k=30,
104
- top_p=0.9,
105
- temperature=0.2,
106
- repetition_penalty=1.02,
107
- stop_sequences=["\nUser:", "<|endoftext|>", "</s>"],
108
- )
109
- stream = client.text_generation(prompt, stream=True, details=True, **gen_kwargs)
110
- report=[]
111
- res_box = st.empty()
112
- collected_chunks=[]
113
- collected_messages=[]
114
- for r in stream:
115
- if r.token.special:
116
- continue
117
- if r.token.text in gen_kwargs["stop_sequences"]:
118
- break
119
- collected_chunks.append(r.token.text)
120
- chunk_message = r.token.text
121
- collected_messages.append(chunk_message)
122
- try:
123
- report.append(r.token.text)
124
- if len(r.token.text) > 0:
125
- result="".join(report).strip()
126
- res_box.markdown(f'*{result}*')
127
- except:
128
- st.write(' ')
 
 
 
 
 
 
129
 
130
  def query(payload):
131
  response = requests.post(API_URL, headers=headers, json=payload)
 
95
 
96
  # Function to Stream Inference Client for Inference Endpoint Responses
97
  def StreamLLMChatResponse(prompt):
98
+
99
+ try:
100
+ endpoint_url = API_URL
101
+ hf_token = API_KEY
102
+ client = InferenceClient(endpoint_url, token=hf_token)
103
+ gen_kwargs = dict(
104
+ max_new_tokens=512,
105
+ top_k=30,
106
+ top_p=0.9,
107
+ temperature=0.2,
108
+ repetition_penalty=1.02,
109
+ stop_sequences=["\nUser:", "<|endoftext|>", "</s>"],
110
+ )
111
+ stream = client.text_generation(prompt, stream=True, details=True, **gen_kwargs)
112
+ report=[]
113
+ res_box = st.empty()
114
+ collected_chunks=[]
115
+ collected_messages=[]
116
+ for r in stream:
117
+ if r.token.special:
118
+ continue
119
+ if r.token.text in gen_kwargs["stop_sequences"]:
120
+ break
121
+ collected_chunks.append(r.token.text)
122
+ chunk_message = r.token.text
123
+ collected_messages.append(chunk_message)
124
+ try:
125
+ report.append(r.token.text)
126
+ if len(r.token.text) > 0:
127
+ result="".join(report).strip()
128
+ res_box.markdown(f'*{result}*')
129
+ except:
130
+ st.write(' ')
131
+ except:
132
+ st.write('DromeLlama is asleep. Starting up now on A10 - please give 5 minutes then retry as KEDA scales up from zero to activate running container(s).')
133
+
134
+
135
 
136
  def query(payload):
137
  response = requests.post(API_URL, headers=headers, json=payload)