alfraser committed on
Commit 0adaf44 · 1 Parent(s): b3911a2

Added user messaging when the endpoint is down.

Files changed (1): src/architectures.py +32 -9
src/architectures.py CHANGED
@@ -279,13 +279,29 @@ class InputRequestScreener(ArchitectureComponent):
 class OutputResponseScreener(ArchitectureComponent):
     description = "Screens outputs for offensive responses."
 
+    def __init__(self):
+        self.api_token = hf_api_token()
+        self.endpoint_url = "https://yl89ru8gdr1wkbej.eu-west-1.aws.endpoints.huggingface.cloud"
+
     def process_request(self, request: ArchitectureRequest) -> None:
         system_prompt = "You are screening for offensive content. In a single word (yes or no), is the response offensive?"
-        llm = HFLlamaChatModel.for_model('meta-llama/Llama-2-7b-chat-hf')
-        if llm is None:
-            raise ValueError(f'Screener model "meta-llama/Llama-2-7b-chat-hf" not set up')
-        response = llm(request.response, system_prompt=system_prompt)
-        if response[0:2].lower() != 'no': # Lean cautious even if the model fails to return yes/no
+        headers = {
+            "Accept": "application/json",
+            "Authorization": f"Bearer {self.api_token}",
+            "Content-Type": "application/json"
+        }
+        query_input = f"[INST] <<SYS>> {system_prompt} <<SYS>> {request.response} [/INST] "
+        payload = {
+            "inputs": query_input,
+            "parameters": {
+                "temperature": 0.1,
+                "max_new_tokens": 10
+            }
+        }
+        llm_response = requests.post(self.endpoint_url, headers=headers, json=payload)
+        generated_text = json.loads(llm_response.text)[0]['generated_text'].strip()
+        print(f"Response screener got LLM response: {generated_text}")
+        if generated_text[0:2].lower() != 'no': # Lean cautious even if the model fails to return yes/no
             request.response = "Sorry - I cannot answer this question. Please try and rephrase it."
             request.early_exit = True
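Note on the new screener: json.loads(llm_response.text) assumes the endpoint answered 200, so a cold endpoint returning 502 (the case this commit handles below for HFInferenceEndpoint) would raise here rather than produce a user message. A more defensive variant might look like the sketch below. This is illustrative only: SCREENER_URL and the HF_API_TOKEN environment variable are placeholders, not values from this repository, and the closing <</SYS>> tag follows the Llama-2 chat template (the commit writes <<SYS>> twice).

import json
import os
import requests

SCREENER_URL = "https://example.endpoints.huggingface.cloud"  # placeholder, not the repo's endpoint

def screen_response(text: str, timeout: float = 10.0) -> bool:
    """Return True when the text should be blocked (or when screening itself fails)."""
    system_prompt = ("You are screening for offensive content. In a single "
                     "word (yes or no), is the response offensive?")
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {os.environ['HF_API_TOKEN']}",  # assumed env var
        "Content-Type": "application/json",
    }
    payload = {
        "inputs": f"[INST] <<SYS>> {system_prompt} <</SYS>> {text} [/INST] ",
        "parameters": {"temperature": 0.1, "max_new_tokens": 10},
    }
    try:
        resp = requests.post(SCREENER_URL, headers=headers, json=payload, timeout=timeout)
        resp.raise_for_status()
        verdict = resp.json()[0]["generated_text"].strip()
    except (requests.RequestException, ValueError, KeyError, IndexError):
        return True  # lean cautious: block if the screener call fails
    return verdict[:2].lower() != "no"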
 
@@ -383,7 +399,7 @@ class HFInferenceEndpoint(ArchitectureComponent):
         """
         desc = f"Model: {self.model_name}; "
         desc += f"Endpoint: {self.endpoint_url}; "
-        desc += f"Max tokens: {self.max_tokens}; "
+        desc += f"Max tokens: {self.max_new_tokens}; "
         desc += f"Temperature: {self.temperature}; "
         desc += f"System prompt: {self.system_prompt}"
         return desc
@@ -407,9 +423,14 @@
                 "max_new_tokens": self.max_new_tokens
             }
         }
-        response = requests.post(self.endpoint_url, headers=headers, json=payload)
-        generated_text = json.loads(response.text)[0]['generated_text'].strip()
-        request.response = generated_text
+        llm_response = requests.post(self.endpoint_url, headers=headers, json=payload)
+        if llm_response.status_code == 200:
+            generated_text = llm_response.json()[0]['generated_text'].strip()
+            request.response = generated_text
+        elif llm_response.status_code == 502:
+            request.response = "Received 502 error from LLM service - service initialising, try again shortly"
+        else:
+            request.response = f"Received {llm_response.status_code} - {llm_response.text}"
 
 
 class ResponseTrimmer(ArchitectureComponent):
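The 502 branch reflects how a scale-to-zero Inference Endpoint behaves while it spins up. If a caller preferred to wait rather than surface the "try again shortly" message immediately, a small retry wrapper could sit around the POST. The sketch below is an assumption for illustration; the attempt count and delay are not part of this commit.

import time
import requests

def post_with_retry(url: str, headers: dict, payload: dict,
                    attempts: int = 5, delay: float = 15.0) -> requests.Response:
    """POST to the endpoint, retrying while it returns 502 (still initialising)."""
    response = requests.post(url, headers=headers, json=payload, timeout=30)
    for _ in range(attempts - 1):
        if response.status_code != 502:
            break  # 200, or a real error the caller should surface
        time.sleep(delay)  # endpoint still starting up; wait and try again
        response = requests.post(url, headers=headers, json=payload, timeout=30)
    return response  # may still be 502: caller falls back to the user message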
@@ -433,3 +454,5 @@ class ResponseTrimmer(ArchitectureComponent):
 
     def config_description(self) -> str:
         return f"Regexes: {self.regex_display}"
+
+
 
 
 