Set up configuration for models on HF and an associated page in the application to allow end-user test chat.
- config/models.json +9 -0
- pages/005_LLM_Models.py +58 -0
- src/common.py +1 -0
- src/models.py +64 -0
config/models.json
ADDED
@@ -0,0 +1,9 @@
+{
+    "models": [
+        {
+            "name": "Llama2 Chat 7B",
+            "id": "meta-llama/Llama-2-7b-chat-hf",
+            "description": "The unmodified 7 billion parameter version of the llama 2 chat model from meta."
+        }
+    ]
+}
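Each entry in the models list supplies the three fields that src/models.py reads: a display name, the Hugging Face repository id, and a free-text description. A quick way to sanity-check a new entry before committing it is to load the file the same way the wrapper does. This is a minimal sketch, not part of the commit, and it assumes it is run from the repository root:

    import json
    import os

    # Load the same file src/models.py reads and confirm each entry carries the
    # keys HFLlamaChatModel expects: name, id and description.
    with open(os.path.join("config", "models.json"), "r") as f:
        entries = json.load(f)["models"]
    for entry in entries:
        missing = {"name", "id", "description"} - entry.keys()
        assert not missing, f"{entry} is missing {missing}"
    print([e["name"] for e in entries])  # e.g. ['Llama2 Chat 7B']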
pages/005_LLM_Models.py
ADDED
@@ -0,0 +1,58 @@
+import streamlit as st
+
+from src.models import HFLlamaChatModel
+from src.st_helpers import st_setup
+
+if st_setup('LLM Models'):
+    st.write("# LLM Models")
+    st.write("The project uses a number of different models which are deployed with other components to form a variety of architectures. This page lists those models, and allows users to interact in isolation just with the model directly, excluding any other architecture components.")
+
+    SESSION_KEY_CHAT_SERVER = 'chat_server'
+    HF_AUTH_KEY_SECRET = 'hf_token'
+    button_count = 0
+
+
+    def button_key() -> str:
+        global button_count
+        button_count += 1
+        return f"btn_{button_count}"
+
+    server_container = st.container()
+    chat_container = st.container()
+
+    with server_container:
+        server_count = len(HFLlamaChatModel.available_models())
+        if server_count == 1:
+            st.write(f'### 1 server configured')
+        else:
+            st.write(f'### {server_count} servers configured')
+
+        with st.container():
+            for i, m_name in enumerate(HFLlamaChatModel.available_models()):
+                with st.container():  # row
+                    content, actions = st.columns([4, 1])
+                    with content:
+                        st.write(m_name)
+
+                    with actions:
+                        if st.button("Chat with server", key=button_key()):
+                            st.session_state[SESSION_KEY_CHAT_SERVER] = m_name
+                            st.rerun()
+                    if i != len(HFLlamaChatModel.available_models()) - 1:
+                        st.divider()
+
+    if SESSION_KEY_CHAT_SERVER in st.session_state:
+        with chat_container:
+            st.write(f"### Chatting with {st.session_state[SESSION_KEY_CHAT_SERVER]}")
+            st.write(
+                "Note this is a simple single prompt call back to the relevant chat server. This is just a toy so you can interact with it and does not manage a chat session history.")
+            with st.chat_message("assistant"):
+                st.write("Chat with me in the box below")
+        if prompt := st.chat_input("Ask a question"):
+            with chat_container:
+                with st.chat_message("user"):
+                    st.write(prompt)
+                chat_model = HFLlamaChatModel.get_model(st.session_state[SESSION_KEY_CHAT_SERVER])
+                response = chat_model(prompt, st.secrets[HF_AUTH_KEY_SECRET])
+                with st.chat_message("assistant"):
+                    st.write(response)
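The page reads the API token from st.secrets under the hf_token key (the HF_AUTH_KEY_SECRET constant above), so the Space needs that secret configured, and a local run needs it in .streamlit/secrets.toml, the default location Streamlit loads secrets from. A minimal pre-flight check for a local run, shown only as a sketch and assuming Python 3.11+ for tomllib:

    import os
    import tomllib  # standard library on Python 3.11+, used here only for this check

    # Confirm the token the page expects is present before launching `streamlit run`.
    secrets_path = os.path.join(".streamlit", "secrets.toml")
    with open(secrets_path, "rb") as f:
        secrets = tomllib.load(f)
    assert "hf_token" in secrets, "add hf_token to .streamlit/secrets.toml (or to the Space secrets)"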
src/common.py
CHANGED
@@ -2,3 +2,4 @@ import os
 
 
 data_dir = os.path.join(os.path.dirname(__file__), '..', 'data')
+config_dir = os.path.join(os.path.dirname(__file__), '..', 'config')
src/models.py
ADDED
@@ -0,0 +1,64 @@
+import json
+import os
+import requests
+from typing import List
+
+from src.common import config_dir
+
+
+class HFLlamaChatModel:
+    models = None
+
+    @classmethod
+    def load_configs(cls):
+        config_file = os.path.join(config_dir, "models.json")
+        with open(config_file, "r") as f:
+            configs = json.load(f)['models']
+        cls.models = []
+        for cfg in configs:
+            if cls.get_model(cfg['name']) is None:
+                cls.models.append(HFLlamaChatModel(cfg['name'], cfg['id'], cfg['description']))
+
+    @classmethod
+    def get_model(cls, model: str):
+        for m in cls.models:
+            if m.name == model:
+                return m
+
+    @classmethod
+    def available_models(cls) -> List[str]:
+        if cls.models is None:
+            cls.load_configs()
+        return [m.name for m in cls.models]
+
+    def __init__(self, name: str, id: str, description: str):
+        self.name = name
+        self.id = id
+        self.description = description
+
+    def __call__(self,
+                 query: str,
+                 auth_token: str,
+                 system_prompt: str = None,
+                 max_new_tokens: str = 256,
+                 temperature: float = 1.0):
+        headers = {"Authorization": f"Bearer {auth_token}"}
+        api_url = f"https://api-inference.huggingface.co/models/{self.id}"
+        if system_prompt is None:
+            system_prompt = "You are a helpful assistant."
+        query_input = f"[INST] <<SYS>> {system_prompt} <<SYS>> {query} [/INST] "
+        query_payload = {
+            "inputs": query_input,
+            "parameters": {"max_new_tokens": max_new_tokens, "temperature": temperature}
+        }
+        print(query_payload)
+        response = requests.post(api_url, headers=headers, json=query_payload)
+        if response.status_code == 200:
+            resp_json = json.loads(response.text)
+            llm_text = resp_json[0]['generated_text']
+            query_len = len(query_input)
+            llm_text = llm_text[query_len:].strip()
+            return llm_text
+        else:
+            error_detail = f"Error from hugging face code: {response.status_code}: {response.reason} ({response.content})"
+            raise ValueError(error_detail)
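Because get_model does not lazy-load the registry itself, callers should go through available_models() (or load_configs()) first, as the Streamlit page does. A minimal sketch of driving the wrapper outside Streamlit, assuming a valid Inference API token in an HF_TOKEN environment variable (the variable name is only an example, not part of the commit):

    import os
    from src.models import HFLlamaChatModel

    # available_models() loads config/models.json on first use and returns the
    # configured display names, e.g. ['Llama2 Chat 7B'].
    names = HFLlamaChatModel.available_models()
    model = HFLlamaChatModel.get_model(names[0])

    # __call__ posts to the HF Inference API and returns only the generated text
    # that follows the prompt.
    reply = model("What is the capital of France?",
                  auth_token=os.environ["HF_TOKEN"],
                  max_new_tokens=128,
                  temperature=0.7)
    print(reply)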