llm-arch / pages /010_LLM_Architectures.py
alfraser's picture
Refactored pages to make the functions smaller and clearer
03dc960
raw
history blame
6.94 kB
import os
from time import time
from typing import Optional

import pandas as pd
import streamlit as st

from src.st_helpers import st_setup
from src.data_synthesis.test_question_generator import generate_question
from src.common import img_dir, escape_dollars, generate_group_tag
from src.architectures import *
# Label added to the architecture picker alongside the real architecture names;
# selecting it shows the side by side comparison instead of a single architecture.
COMPARE = "Side by side compare"
def show_side_by_side() -> None:
    """
    Streamlit render a prompt and one column per selected architecture,
    then send the prompt to each architecture in turn and show how long
    each took and what it responded
    :return: None
    """
    # Build the layout structure up front so the header always sits above the columns
    st.divider()
    header_container = st.container()
    arch_outer_container = st.container()

    # Only set once at least one architecture is selected and the user asks something
    prompt = None

    # Build header
    with header_container:
        st.write("### Side by side comparison of architectures")
        st.write('Enter a question below to have it sent to the selected architectures to compare timing and response.')
        options = [a.name for a in Architecture.architectures]
        selected_archs = st.multiselect("Select architectures to use", options=options, default=options)
        if not selected_archs:
            st.write("To get started select some architectures to compare")
        else:
            prompt = st.chat_input("Ask a question")
            if st.button("Or press to ask a random question"):
                prompt = generate_question()
            if prompt:
                st.write(f"**Question:** {prompt}")

    # Nothing selected means there are no columns to draw
    if not selected_archs:
        return

    with arch_outer_container:
        arch_cols = st.columns(len(selected_archs))

        # The column headings are shown whether or not a question has been asked yet
        for col, arch_name in zip(arch_cols, selected_archs):
            with col:
                st.write(f'#### {arch_name}')

        if not prompt:
            return

        # Now dispatch the messages per architecture, sharing one group tag across the requests
        group_tag = generate_group_tag()
        for col, arch_name in zip(arch_cols, selected_archs):
            request = ArchitectureRequest(query=prompt)
            arch = Architecture.get_architecture(arch_name)
            with col:
                with st.spinner('Architecture processing request'):
                    start = time()
                    arch(request, trace_tags=["UI", "SideBySideCompare", group_tag])
                    elapsed_in_s = int((time() - start) * 10) / 10  # truncated to 1dp in seconds
                st.write('##### Timing')
                st.write(f'Request took **{elapsed_in_s}s**')
                st.write('##### Response')
                # Same escaping as applied to responses in the single architecture chat view
                st.write(escape_dollars(request.response))
def display_architecture_in_container(arch, arch_container) -> None:
    """
    Streamlit render the details of an architecture: its name, description,
    image (when it has one) and a table of its pipeline steps
    :param arch: the architecture to describe; its name, description, img and steps are read
    :param arch_container: the streamlit container to render into
    """
    with arch_container:
        st.divider()
        st.write(f'### {arch.name}')
        st.write('#### Architecture description')
        st.write(arch.description)

        # The image is optional; arch.img is joined onto the shared image directory
        if arch.img is not None:
            img = os.path.join(img_dir, arch.img)
            st.image(img, caption=f'{arch.name} As Built', width=1000)

        # One table row per pipeline step, numbered from 1 for display
        table_cols = ['Step', 'Name', 'Description', 'Config details']
        table_data = [
            [step_no, step.__class__.__name__, step.description, step.config_description()]
            for step_no, step in enumerate(arch.steps, start=1)
        ]
        st.write('#### Architecture pipeline steps')
        st.table(pd.DataFrame(table_data, columns=table_cols))
def display_architecture_chat_in_container(arch, chat_container) -> None:
    """
    Streamlit render a single-question chat against one architecture, with
    the architecture trace and the full request/response shown alongside
    :param arch: the architecture to send the question through
    :param chat_container: the streamlit container to render into
    """
    with chat_container:
        st.write(f"### Chat with {arch.name}")
        st.write("Note this is a simple single query through the relevant architecture. This is just a sample so you can interact with it and does not manage a chat session history.")

        # A typed question, replaced by a generated one if the button was pressed
        question = st.chat_input("Ask a question")
        if st.button("Or press to ask a random question"):
            question = generate_question()

        conversation_col, trace_col, request_col = st.columns([3, 2, 2])

        # The greeting is always shown, even before anything has been asked
        with conversation_col:
            with st.chat_message("assistant"):
                st.write("Chat with me in the box below")

        if not question:
            return

        # Echo the question, run it through the architecture and show the answer
        with conversation_col:
            with st.chat_message("user"):
                st.write(question)
            request = ArchitectureRequest(query=question)
            trace = arch(request, trace_tags=["UI", "SingleArchTest"])
            with st.chat_message("assistant"):
                st.write(escape_dollars(request.response))

        with trace_col:
            st.write("#### Architecture Trace")
            st.markdown(trace.as_markdown())

        with request_col:
            st.write("#### Full Request/Response")
            st.markdown(request.as_markdown())
def show_architecture(architecture: str) -> None:
    """
    Streamlit render the details of one architecture followed by
    a chat area to interact with it
    :param architecture: the name of the architecture to output
    """
    # Create both containers before filling either: details first, chat below
    details_area, chat_area = st.container(), st.container()

    selected = Architecture.get_architecture(architecture)
    display_architecture_in_container(selected, details_area)
    display_architecture_chat_in_container(selected, chat_area)
def show_sub_header() -> None:
    """
    Write a subheader stating how many architectures are configured,
    using the singular wording when there is exactly one
    """
    total = len(Architecture.architectures)
    noun = 'Architecture' if total == 1 else 'Architectures'
    st.write(f'### {total} {noun} available')
def show_reload_button() -> None:
    """
    Show a reload button; when it is clicked the architectures are
    loaded again with force_reload set
    """
    reload_requested = st.button("Force reload of architecture configs")
    if reload_requested:
        Architecture.load_architectures(force_reload=True)
def get_user_selected_architecture() -> Optional[str]:
    """
    Display a picker of all the architectures plus the option to do a side by side compare
    :return: the selected architecture name, the COMPARE label, or None
             while nothing has been selected (the picker starts with no selection)
    """
    # The compare option is listed last, after the real architecture names
    arch_names = [a.name for a in Architecture.architectures] + [COMPARE]
    return st.radio(label="Available architectures", label_visibility="hidden", options=arch_names, index=None)
# Page body: only rendered when st_setup returns a truthy value
if st_setup('LLM Arch'):
    st.write("# LLM Architectures")
    Architecture.load_architectures()
    show_sub_header()
    show_reload_button()

    # Route to the matching view based on the picker selection
    selected_arch = get_user_selected_architecture()
    if selected_arch == COMPARE:
        show_side_by_side()
    elif selected_arch is not None:
        show_architecture(selected_arch)
    else:
        st.info('Select an architecture from above to see details and interact with it')