File size: 4,142 Bytes
e18c8b0 19ee1e4 e18c8b0 a6b53fb e18c8b0 7d317a5 e18c8b0 5bfd036 e18c8b0 7d317a5 e18c8b0 d4ce695 e18c8b0 7bcd3fa e18c8b0 19ee1e4 7d317a5 e18c8b0 d4ce695 5bfd036 e18c8b0 19ee1e4 e18c8b0 19ee1e4 e18c8b0 19ee1e4 e18c8b0 19ee1e4 e18c8b0 19ee1e4 e18c8b0 19ee1e4 e18c8b0 19ee1e4 e18c8b0 a6b53fb e18c8b0 a6b53fb 7bcd3fa e18c8b0 a6b53fb 7bcd3fa e18c8b0 a6b53fb e18c8b0 a6b53fb e18c8b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import logging
import os
import traceback
import pandas as pd
import streamlit as st
from llm_guard.vault import Vault
from output import init_settings as init_output_settings
from output import scan as scan_output
from prompt import init_settings as init_prompt_settings
from prompt import scan as scan_prompt
# The two scanning modes offered in the sidebar "Type" selector.
PROMPT, OUTPUT = "prompt", "output"

# One Vault instance, handed to both scan_prompt and scan_output below.
vault = Vault()

# Page-level Streamlit configuration (title, wide layout, open sidebar).
st.set_page_config(
    page_title="LLM Guard Playground",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={"About": "https://llm-guard.com/"},
)

# Application logger; INFO and above are let through at the logger level.
logger = logging.getLogger("llm-guard-playground")
logger.setLevel(logging.INFO)
# Sidebar: app header, scanner mode selector and the fail-fast toggle.
with st.sidebar:
    st.header(
        """
Scanning prompt and output using [LLM Guard](https://llm-guard.com/)
"""
    )
    scanner_type = st.selectbox("Type", [PROMPT, OUTPUT], index=0)
    st_fail_fast = st.checkbox(
        "Fail fast", value=False, help="Stop scanning after first failure"
    )

# Each mode has its own settings initializer; its result is unpacked into
# (enabled_scanners, settings). An unknown mode leaves both as None.
_settings_initializer = {
    PROMPT: init_prompt_settings,
    OUTPUT: init_output_settings,
}.get(scanner_type)
if _settings_initializer is None:
    enabled_scanners, settings = None, None
else:
    enabled_scanners, settings = _settings_initializer()
# Main panel: heading for the selected mode plus a collapsible About box.
if scanner_type == PROMPT:
    _main_title = "Guard Prompt"
else:
    _main_title = "Guard Output"
st.subheader(_main_title)

_about_text = """LLM-Guard is a comprehensive tool designed to fortify the security of Large Language Models (LLMs).
\n\n[Code](https://github.com/protectai/llm-guard) |
[Documentation](https://llm-guard.com/)"""
st.expander("About", expanded=False).info(_about_text)

# Status message that is cleared again on the very next line, so it is not
# expected to remain visible once this part of the script has run.
analyzer_load_state = st.info("Starting LLM Guard...")
analyzer_load_state.empty()
# Example inputs shipped with the playground: every ``.txt`` file in these
# folders is offered as a selectable example.
prompt_examples_folder = "./examples/prompt"
output_examples_folder = "./examples/output"


def _list_examples(folder):
    """Return the names of the ``.txt`` files in *folder*, sorted.

    ``os.listdir`` returns entries in arbitrary, filesystem-dependent order;
    sorting keeps the option list (and therefore the ``index=0`` default)
    stable across machines and runs.
    """
    return sorted(name for name in os.listdir(folder) if name.endswith(".txt"))


def _read_example(folder, name):
    """Return the UTF-8 text of example file *name* inside *folder*.

    Returns ``""`` when *name* is ``None``, which is what ``st.selectbox``
    yields when it has nothing to select (e.g. a folder with no ``.txt``
    files); previously this crashed in ``os.path.join``.
    """
    if name is None:
        return ""
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(os.path.join(folder, name), "r", encoding="utf-8") as file:
        return file.read()


prompt_examples = _list_examples(prompt_examples_folder)
output_examples = _list_examples(output_examples_folder)

if scanner_type == PROMPT:
    st_prompt_example = st.selectbox("Select prompt example", prompt_examples, index=0)
    prompt_example_text = _read_example(prompt_examples_folder, st_prompt_example)
    st_prompt_text = st.text_area(
        label="Enter prompt", value=prompt_example_text, height=200, key="prompt_text_input"
    )
elif scanner_type == OUTPUT:
    # Output mode scans a prompt/output pair, shown side by side.
    col1, col2 = st.columns(2)
    st_prompt_example = col1.selectbox("Select prompt example", prompt_examples, index=0)
    prompt_example_text = _read_example(prompt_examples_folder, st_prompt_example)
    st_prompt_text = col1.text_area(
        label="Enter prompt", value=prompt_example_text, height=300, key="prompt_text_input"
    )
    st_output_example = col2.selectbox("Select output example", output_examples, index=0)
    output_example_text = _read_example(output_examples_folder, st_output_example)
    st_output_text = col2.text_area(
        label="Enter output", value=output_example_text, height=300, key="output_text_input"
    )
# Scan outcome consumed by the results section; the results are only shown
# once st_is_valid has been set by a successful submission.
st_result_text = None
st_analysis = None
st_is_valid = None
try:
    with st.form("text_form", clear_on_submit=False):
        submitted = st.form_submit_button("Process")
        if submitted:
            # Scanners report a sequence of mappings carrying an "is_valid"
            # key (see the all(...) below), so start from an empty list.
            results = []
            if scanner_type == PROMPT:
                st_result_text, results = scan_prompt(
                    vault, enabled_scanners, settings, st_prompt_text, st_fail_fast
                )
            elif scanner_type == OUTPUT:
                st_result_text, results = scan_output(
                    vault, enabled_scanners, settings, st_prompt_text, st_output_text, st_fail_fast
                )
            st_is_valid = all(item["is_valid"] for item in results)
            st_analysis = results
except Exception as e:
    # Top-level UI boundary: log the error together with its traceback via
    # the logging pipeline, then surface it in the page instead of crashing.
    logger.exception(e)
    st.error(e)
# Results: once a scan has produced a verdict, show the sanitized text next
# to a table of the per-scanner results.
if st_is_valid is not None:
    verdict = "valid" if st_is_valid else "invalid"
    st.subheader(f"Results - {verdict}")
    sanitized_col, analysis_col = st.columns(2)
    sanitized_col.text_area(label="Sanitized text", value=st_result_text, height=400)
    analysis_col.table(pd.DataFrame(st_analysis))
|