alfraser committed on
Commit 6000142 · 1 Parent(s): 0cd8882

Added a new architecture component which calls Hugging Face via a dedicated inference endpoint rather than the HTTP interface - needed due to the size of the fine-tuned model

Files changed (1)
src/architectures.py +47 -1
src/architectures.py CHANGED
@@ -7,6 +7,7 @@ import chromadb
 import json
 import os
 import regex as re
+import requests
 import traceback
 
 from abc import ABC, abstractmethod
@@ -363,6 +364,52 @@ class HFLlamaHttpRequestor(ArchitectureComponent):
         request.response = response
 
 
+class HFInferenceEndpoint(ArchitectureComponent):
+    """
+    A concrete pipeline component which sends the user text to a given Llama-chat-based
+    inference endpoint on Hugging Face
+    """
+    def __init__(self, endpoint_url: str, system_prompt: str, max_new_tokens: int, temperature: float = 1.0):
+        self.endpoint_url: str = endpoint_url
+        self.system_prompt: str = system_prompt
+        self.max_new_tokens = max_new_tokens
+        self.api_token = hf_api_token()
+        self.temperature = temperature
+
+    def config_description(self) -> str:
+        """
+        Custom config details as markdown
+        """
+        desc = f"Endpoint: {self.endpoint_url}; "
+        desc += f"Max new tokens: {self.max_new_tokens}; "
+        desc += f"Temperature: {self.temperature}; "
+        desc += f"System prompt: {self.system_prompt}"
+        return desc
+
+    def process_request(self, request: ArchitectureRequest) -> None:
+        """
+        Main processing method for this component. Posts the user text to the
+        configured Hugging Face inference endpoint and writes the generated
+        text into the response element of the request.
+        """
+        headers = {
+            "Accept": "application/json",
+            "Authorization": f"Bearer {self.api_token}",
+            "Content-Type": "application/json"
+        }
+        query_input = f"[INST] <<SYS>> {self.system_prompt} <</SYS>> {request.request} [/INST] "
+        payload = {
+            "inputs": query_input,
+            "parameters": {
+                "temperature": self.temperature,
+                "max_new_tokens": self.max_new_tokens
+            }
+        }
+        response = requests.post(self.endpoint_url, headers=headers, json=payload)
+        generated_text = response.json()[0]['generated_text'].strip()
+        request.response = generated_text
+
+
 class ResponseTrimmer(ArchitectureComponent):
     """
     A concrete pipeline component which trims the response based on a regex match,
@@ -384,4 +431,3 @@ class ResponseTrimmer(ArchitectureComponent):
 
     def config_description(self) -> str:
         return f"Regexes: {self.regex_display}"
-