alfraser committed on
Commit 54b3256 · 1 Parent(s): 326698c

Set up configuration for models on HF and an associated application page to allow end-user test chat.

Files changed (4)
  1. config/models.json +9 -0
  2. pages/005_LLM_Models.py +58 -0
  3. src/common.py +1 -0
  4. src/models.py +64 -0
config/models.json ADDED
@@ -0,0 +1,9 @@
+{
+  "models": [
+    {
+      "name": "Llama2 Chat 7B",
+      "id": "meta-llama/Llama-2-7b-chat-hf",
+      "description": "The unmodified 7 billion parameter version of the Llama 2 chat model from Meta."
+    }
+  ]
+}
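For reference, a minimal sketch of loading this config and checking the keys that src/models.py reads ('name', 'id', 'description'); the validation loop is illustrative only and not part of the commit:

```python
import json
import os

# Load config/models.json from the repo root and verify each entry
# carries the fields HFLlamaChatModel.load_configs() expects.
config_file = os.path.join("config", "models.json")
with open(config_file) as f:
    models = json.load(f)["models"]

for cfg in models:
    missing = {"name", "id", "description"} - cfg.keys()
    if missing:
        raise ValueError(f"{cfg} is missing keys: {missing}")
    print(f"{cfg['name']} -> {cfg['id']}")
```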
pages/005_LLM_Models.py ADDED
@@ -0,0 +1,58 @@
+import streamlit as st
+
+from src.models import HFLlamaChatModel
+from src.st_helpers import st_setup
+
+if st_setup('LLM Models'):
+    st.write("# LLM Models")
+    st.write("The project uses a number of different models which are deployed with other components to form a variety of architectures. This page lists those models and allows users to interact with each model directly, in isolation from any other architecture components.")
+
+    SESSION_KEY_CHAT_SERVER = 'chat_server'
+    HF_AUTH_KEY_SECRET = 'hf_token'
+    button_count = 0
+
+
+    def button_key() -> str:
+        global button_count
+        button_count += 1
+        return f"btn_{button_count}"
+
+    server_container = st.container()
+    chat_container = st.container()
+
+    with server_container:
+        server_count = len(HFLlamaChatModel.available_models())
+        if server_count == 1:
+            st.write('### 1 server configured')
+        else:
+            st.write(f'### {server_count} servers configured')
+
+        with st.container():
+            for i, m_name in enumerate(HFLlamaChatModel.available_models()):
+                with st.container():  # row
+                    content, actions = st.columns([4, 1])
+                    with content:
+                        st.write(m_name)
+
+                    with actions:
+                        if st.button("Chat with server", key=button_key()):
+                            st.session_state[SESSION_KEY_CHAT_SERVER] = m_name
+                            st.rerun()
+                if i != len(HFLlamaChatModel.available_models()) - 1:
+                    st.divider()
+
+    if SESSION_KEY_CHAT_SERVER in st.session_state:
+        with chat_container:
+            st.write(f"### Chatting with {st.session_state[SESSION_KEY_CHAT_SERVER]}")
+            st.write(
+                "Note this is a simple single-prompt call to the relevant chat server. It is just a toy so you can interact with the model; it does not manage a chat session history.")
+            with st.chat_message("assistant"):
+                st.write("Chat with me in the box below")
+        if prompt := st.chat_input("Ask a question"):
+            with chat_container:
+                with st.chat_message("user"):
+                    st.write(prompt)
+                chat_model = HFLlamaChatModel.get_model(st.session_state[SESSION_KEY_CHAT_SERVER])
+                response = chat_model(prompt, st.secrets[HF_AUTH_KEY_SECRET])
+                with st.chat_message("assistant"):
+                    st.write(response)
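The page deliberately keeps no chat history. As a point of comparison, a minimal sketch of how a transcript could be kept across Streamlit reruns with st.session_state; the chat_history key and tuple structure are illustrative, not part of this commit:

```python
import streamlit as st

MESSAGES_KEY = "chat_history"  # hypothetical key, not used by the commit

if MESSAGES_KEY not in st.session_state:
    st.session_state[MESSAGES_KEY] = []

# Replay earlier turns so the transcript survives Streamlit's rerun-per-interaction model.
for role, text in st.session_state[MESSAGES_KEY]:
    with st.chat_message(role):
        st.write(text)

if prompt := st.chat_input("Ask a question"):
    st.session_state[MESSAGES_KEY].append(("user", prompt))
    with st.chat_message("user"):
        st.write(prompt)
    # response = chat_model(prompt, st.secrets['hf_token'])  # as in the page above
    # st.session_state[MESSAGES_KEY].append(("assistant", response))
```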
src/common.py CHANGED
@@ -2,3 +2,4 @@ import os
 
 
 data_dir = os.path.join(os.path.dirname(__file__), '..', 'data')
+config_dir = os.path.join(os.path.dirname(__file__), '..', 'config')
src/models.py ADDED
@@ -0,0 +1,64 @@
+import json
+import os
+import requests
+from typing import List
+
+from src.common import config_dir
+
+
+class HFLlamaChatModel:
+    models = None
+
+    @classmethod
+    def load_configs(cls):
+        config_file = os.path.join(config_dir, "models.json")
+        with open(config_file, "r") as f:
+            configs = json.load(f)['models']
+        cls.models = []
+        for cfg in configs:
+            if cls.get_model(cfg['name']) is None:
+                cls.models.append(HFLlamaChatModel(cfg['name'], cfg['id'], cfg['description']))
+
+    @classmethod
+    def get_model(cls, model: str):
+        for m in cls.models:
+            if m.name == model:
+                return m
+
+    @classmethod
+    def available_models(cls) -> List[str]:
+        if cls.models is None:
+            cls.load_configs()
+        return [m.name for m in cls.models]
+
+    def __init__(self, name: str, id: str, description: str):
+        self.name = name
+        self.id = id
+        self.description = description
+
+    def __call__(self,
+                 query: str,
+                 auth_token: str,
+                 system_prompt: str = None,
+                 max_new_tokens: int = 256,
+                 temperature: float = 1.0):
+        headers = {"Authorization": f"Bearer {auth_token}"}
+        api_url = f"https://api-inference.huggingface.co/models/{self.id}"
+        if system_prompt is None:
+            system_prompt = "You are a helpful assistant."
+        query_input = f"[INST] <<SYS>> {system_prompt} <</SYS>> {query} [/INST] "
+        query_payload = {
+            "inputs": query_input,
+            "parameters": {"max_new_tokens": max_new_tokens, "temperature": temperature}
+        }
+        print(query_payload)
+        response = requests.post(api_url, headers=headers, json=query_payload)
+        if response.status_code == 200:
+            resp_json = json.loads(response.text)
+            llm_text = resp_json[0]['generated_text']
+            query_len = len(query_input)
+            llm_text = llm_text[query_len:].strip()
+            return llm_text
+        else:
+            error_detail = f"Error from Hugging Face, code {response.status_code}: {response.reason} ({response.content})"
+            raise ValueError(error_detail)
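A hedged usage sketch of the class outside Streamlit, assuming an HF_TOKEN environment variable holds a valid Hugging Face token (the token handling here is not part of the commit):

```python
import os
from src.models import HFLlamaChatModel

# First call triggers load_configs() and lists the names from config/models.json.
print(HFLlamaChatModel.available_models())  # expected: ['Llama2 Chat 7B']

model = HFLlamaChatModel.get_model('Llama2 Chat 7B')
# __call__ posts a Llama 2 chat prompt to the HF Inference API and
# strips the echoed prompt from the generated text.
reply = model("What is the capital of France?", os.environ["HF_TOKEN"])
print(reply)
```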