Spaces:
Runtime error
Runtime error
import os
from time import time
from typing import Optional

import pandas as pd
import streamlit as st

from src.st_helpers import st_setup
from src.data_synthesis.test_question_generator import generate_question
from src.common import img_dir, escape_dollars, generate_group_tag
from src.architectures import *
# Extra option appended to the architecture picker; choosing it switches the
# page to the side-by-side comparison view instead of a single architecture.
COMPARE = "Side by side compare"
def show_side_by_side() -> None:
    """
    Streamlit render a prompt and one column per selected architecture so the
    same question can be compared across architectures for timing and response.

    The user picks architectures from a multiselect (all selected by default),
    then either types a question or presses a button to generate a random one.
    Each selected architecture is then called in turn with the same question,
    and all of those calls share a single group tag in their trace tags.
    :return: None
    """
    # Build the layout structure
    st.divider()
    header_container = st.container()
    arch_outer_container = st.container()

    # Build header
    prompt = None
    with header_container:
        st.write("### Side by side comparison of architectures")
        st.write('Enter a question below to have it sent to the selected architectures to compare timing and response.')
        options = [a.name for a in Architecture.architectures]
        selected_archs = st.multiselect("Select architectures to use", options=options, default=options)
        if not selected_archs:
            st.write("To get started select some architectures to compare")
        else:
            prompt = st.chat_input("Ask a question")
            if st.button("Or press to ask a random question"):
                prompt = generate_question()
            if prompt:
                st.write(f"**Question:** {prompt}")

    # Nothing to lay out until at least one architecture is selected
    if not selected_archs:
        return

    with arch_outer_container:
        arch_cols = st.columns(len(selected_archs))
        # Column headers are shown whether or not a question has been asked yet
        for col, name in zip(arch_cols, selected_archs):
            with col:
                st.write(f'#### {name}')

        if not prompt:
            return

        # One shared tag lets the traces from this comparison be grouped together
        group_tag = generate_group_tag()
        for col, name in zip(arch_cols, selected_archs):
            request = ArchitectureRequest(query=prompt)
            arch = Architecture.get_architecture(name)
            with col:
                with st.spinner('Architecture processing request'):
                    start = time()
                    arch(request, trace_tags=["UI", "SideBySideCompare", group_tag])
                    # Round (not truncate) to 1 decimal place, in seconds
                    elapsed_in_s = round(time() - start, 1)
                st.write('##### Timing')
                st.write(f'Request took **{elapsed_in_s}s**')
                st.write('##### Response')
                # Escape '$' so Streamlit does not render the response as LaTeX,
                # matching the single-architecture chat view
                st.write(escape_dollars(request.response))
def display_architecture_in_container(arch, arch_container) -> None:
    """
    Render an architecture's description, optional image and pipeline-step table.

    :param arch: the architecture to render; its ``name``, ``description``,
        ``img`` and ``steps`` attributes are read
    :param arch_container: the Streamlit container to render into
    """
    with arch_container:
        st.divider()
        st.write(f'### {arch.name}')
        st.write('#### Architecture description')
        st.write(arch.description)
        # The image is optional; its path is relative to the shared image directory
        if arch.img is not None:
            img = os.path.join(img_dir, arch.img)
            st.image(img, caption=f'{arch.name} As Built', width=1000)
        table_cols = ['Step', 'Name', 'Description', 'Config details']
        # Steps are numbered from 1 for display
        table_data = [
            [j, type(s).__name__, s.description, s.config_description()]
            for j, s in enumerate(arch.steps, start=1)
        ]
        st.write('#### Architecture pipeline steps')
        st.table(pd.DataFrame(table_data, columns=table_cols))
def display_architecture_chat_in_container(arch, chat_container) -> None:
    """
    Render a one-shot chat box for a single architecture.

    A question is taken from the chat input, or generated when the random-question
    button is pressed. It is sent once through ``arch``, and the answer, the
    returned trace and the full request are shown in three side-by-side columns.
    No chat history is kept between questions.

    :param arch: the architecture to query
    :param chat_container: the Streamlit container to render into
    """
    with chat_container:
        st.write(f"### Chat with {arch.name}")
        st.write("Note this is a simple single query through the relevant architecture. This is just a sample so you can interact with it and does not manage a chat session history.")
        question = st.chat_input("Ask a question")
        if st.button("Or press to ask a random question"):
            question = generate_question()

        chat_col, trace_col, request_col = st.columns([3, 2, 2])
        with chat_col, st.chat_message("assistant"):
            st.write("Chat with me in the box below")

        # Without a question there is nothing more to show
        if not question:
            return

        with chat_col:
            with st.chat_message("user"):
                st.write(question)
            arch_request = ArchitectureRequest(query=question)
            arch_trace = arch(arch_request, trace_tags=["UI", "SingleArchTest"])
            with st.chat_message("assistant"):
                st.write(escape_dollars(arch_request.response))

        with trace_col:
            st.write("#### Architecture Trace")
            st.markdown(arch_trace.as_markdown())

        with request_col:
            st.write("#### Full Request/Response")
            st.markdown(arch_request.as_markdown())
def show_architecture(architecture: str) -> None:
    """
    Streamlit render the details of one architecture, followed by a chat box
    for interacting with it.

    :param architecture: the name of the architecture to output
    """
    selected = Architecture.get_architecture(architecture)
    # The details container is created first, so it renders above the chat
    details_box, chat_box = st.container(), st.container()
    display_architecture_in_container(selected, details_box)
    display_architecture_chat_in_container(selected, chat_box)
def show_sub_header() -> None:
    """
    Write a subheader giving the number of configured architectures, using the
    singular noun when there is exactly one.
    """
    total = len(Architecture.architectures)
    noun = 'Architecture' if total == 1 else 'Architectures'
    st.write(f'### {total} {noun} available')
def show_reload_button() -> None:
    """
    Show a button that, when clicked, forces the architecture configs to be
    reloaded.
    """
    clicked = st.button("Force reload of architecture configs")
    if not clicked:
        return
    Architecture.load_architectures(force_reload=True)
def get_user_selected_architecture() -> Optional[str]:
    """
    Show a radio picker listing every architecture name plus the side-by-side
    compare option.

    :return: the chosen option, or None while nothing has been picked
        (the radio starts with no selection)
    """
    choices = [arch.name for arch in Architecture.architectures] + [COMPARE]
    return st.radio(label="Available architectures", label_visibility="hidden", options=choices, index=None)
# Page entry point: render only when st_setup reports the page is ready.
if st_setup('LLM Arch'):
    st.write("# LLM Architectures")
    Architecture.load_architectures()
    show_sub_header()
    show_reload_button()
    choice = get_user_selected_architecture()
    if choice == COMPARE:
        show_side_by_side()
    elif choice is not None:
        show_architecture(choice)
    else:
        # The radio starts unselected, so prompt the user to choose
        st.info('Select an architecture from above to see details and interact with it')