Spaces:
Runtime error
Runtime error
Added a new architecture component which calls Hugging Face via a dedicated inference endpoint rather than the HTTP interface — needed due to the size of the fine-tuned model
Browse files- src/architectures.py +47 -1
src/architectures.py
CHANGED
@@ -7,6 +7,7 @@ import chromadb
|
|
7 |
import json
|
8 |
import os
|
9 |
import regex as re
|
|
|
10 |
import traceback
|
11 |
|
12 |
from abc import ABC, abstractmethod
|
@@ -363,6 +364,52 @@ class HFLlamaHttpRequestor(ArchitectureComponent):
|
|
363 |
request.response = response
|
364 |
|
365 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
366 |
class ResponseTrimmer(ArchitectureComponent):
|
367 |
"""
|
368 |
A concrete pipeline component which trims the response based on a regex match,
|
@@ -384,4 +431,3 @@ class ResponseTrimmer(ArchitectureComponent):
|
|
384 |
|
385 |
def config_description(self) -> str:
|
386 |
return f"Regexes: {self.regex_display}"
|
387 |
-
|
|
|
7 |
import json
|
8 |
import os
|
9 |
import regex as re
|
10 |
+
import requests
|
11 |
import traceback
|
12 |
|
13 |
from abc import ABC, abstractmethod
|
|
|
364 |
request.response = response
|
365 |
|
366 |
|
367 |
+
class HFInferenceEndpoint(ArchitectureComponent):
    """
    A concrete pipeline component which sends the user text to a given llama chat based
    inference endpoint on HuggingFace.

    Used for fine-tuned models too large to serve through the plain HTTP model
    interface; the endpoint is a dedicated HuggingFace Inference Endpoint URL.
    """

    def __init__(self, endpoint_url: str, system_prompt: str, max_new_tokens: int, temperature: float = 1.0):
        self.endpoint_url: str = endpoint_url
        self.system_prompt: str = system_prompt
        self.max_new_tokens: int = max_new_tokens
        # Token is resolved once at construction time from the environment helper
        self.api_token = hf_api_token()
        self.temperature: float = temperature

    def config_description(self) -> str:
        """
        Custom config details as markdown
        """
        desc = f"Endpoint: {self.endpoint_url}; "
        # BUG FIX: previously read self.max_tokens, which is never assigned
        # (__init__ stores self.max_new_tokens) and raised AttributeError.
        desc += f"Max tokens: {self.max_new_tokens}; "
        desc += f"Temperature: {self.temperature}; "
        desc += f"System prompt: {self.system_prompt}"
        return desc

    def process_request(self, request: ArchitectureRequest) -> None:
        """
        POST the user's request text to the configured HuggingFace inference
        endpoint, wrapped in the llama-2 chat prompt format together with the
        configured system prompt, and store the generated text on the
        response element of the request.
        """
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json"
        }
        # Llama-2 chat format closes the system block with <</SYS>>.
        # BUG FIX: the original emitted a second opening <<SYS>> tag, which
        # corrupts the prompt the chat model is trained to expect.
        query_input = f"[INST] <<SYS>> {self.system_prompt} <</SYS>> {request.request} [/INST] "
        payload = {
            "inputs": query_input,
            "parameters": {
                "temperature": self.temperature,
                "max_new_tokens": self.max_new_tokens
            }
        }
        response = requests.post(self.endpoint_url, headers=headers, json=payload)
        # Fail loudly on HTTP-level errors instead of attempting to parse an
        # error body as a generation result.
        response.raise_for_status()
        # Endpoint returns a JSON list: [{"generated_text": "..."}]
        generated_text = response.json()[0]['generated_text'].strip()
        request.response = generated_text
413 |
class ResponseTrimmer(ArchitectureComponent):
|
414 |
"""
|
415 |
A concrete pipeline component which trims the response based on a regex match,
|
|
|
431 |
|
432 |
def config_description(self) -> str:
|
433 |
return f"Regexes: {self.regex_display}"
|
|