# Non-code residue that preceded the source (looks like file-viewer page chrome):
# asofter's picture | * initial | e18c8b0 | raw | history blame | 11.8 kB
import logging
from typing import Dict, List, Tuple

import streamlit as st
from streamlit_tags import st_tags

from llm_guard.input_scanners.anonymize import default_entity_types
from llm_guard.output_scanners import (
    BanSubstrings,
    BanTopics,
    Bias,
    Code,
    Deanonymize,
    MaliciousURLs,
    NoRefusal,
    Refutation,
    Regex,
    Relevance,
    Sensitive,
)
from llm_guard.output_scanners.sentiment import Sentiment
from llm_guard.output_scanners.toxicity import Toxicity
from llm_guard.vault import Vault
# Module-level logger; get_scanner() writes a debug line to it for every scanner it builds.
logger = logging.getLogger("llm-guard-demo")
def _non_empty_lines(text: str) -> List[str]:
    """Split ``text`` on newlines and drop blank (empty or whitespace-only) lines.

    A trailing newline or an empty row in a text area would otherwise yield an
    empty string, and an empty substring / empty regex matches every output.
    """
    return [line for line in text.split("\n") if line.strip()]


def _threshold_slider(
    key: str,
    value: float,
    *,
    min_value: float = 0.0,
    step: float = 0.05,
    help_text=None,
) -> float:
    """Render a "Threshold" slider (upper bound 1.0) in the current container.

    Args:
        key: Streamlit widget key; must be unique across the page.
        value: Initial slider position.
        min_value: Lower bound of the slider.
        step: Slider increment.
        help_text: Optional tooltip text; ``None`` shows no tooltip.

    Returns:
        The value currently selected on the slider.
    """
    return st.slider(
        label="Threshold",
        value=value,
        min_value=min_value,
        max_value=1.0,
        step=step,
        key=key,
        help=help_text,
    )


def init_settings() -> Tuple[List, Dict]:
    """Render the sidebar controls and collect the scanner configuration.

    A multiselect chooses which scanners are enabled; every enabled scanner
    that has options gets its own collapsed sidebar expander. "Deanonymize"
    has no options, so it never gets an entry in the returned settings.

    Returns:
        A ``(enabled_scanners, settings)`` pair: the scanner names picked in
        the multiselect, and a dict mapping scanner name to that scanner's
        options dict.
    """
    all_scanners = [
        "BanSubstrings",
        "BanTopics",
        "Bias",
        "Code",
        "Deanonymize",
        "MaliciousURLs",
        "NoRefusal",
        "Refutation",
        "Regex",
        "Relevance",
        "Sensitive",
        "Sentiment",
        "Toxicity",
    ]

    st_enabled_scanners = st.sidebar.multiselect(
        "Select scanners",
        options=all_scanners,
        default=all_scanners,
        help="The list can be found here: https://laiyer-ai.github.io/llm-guard/output_scanners/bias/",
    )

    settings = {}

    if "BanSubstrings" in st_enabled_scanners:
        with st.sidebar.expander("Ban Substrings", expanded=False):
            # The default value ends with "\n"; blank lines are filtered out so
            # that no empty substring reaches the scanner.
            st_bs_substrings = _non_empty_lines(
                st.text_area(
                    "Enter substrings to ban (one per line)",
                    value="test\nhello\nworld\n",
                    height=200,
                )
            )
            st_bs_match_type = st.selectbox("Match type", ["str", "word"])
            st_bs_case_sensitive = st.checkbox("Case sensitive", value=False)

        settings["BanSubstrings"] = {
            "substrings": st_bs_substrings,
            "match_type": st_bs_match_type,
            "case_sensitive": st_bs_case_sensitive,
        }

    if "BanTopics" in st_enabled_scanners:
        with st.sidebar.expander("Ban Topics", expanded=False):
            st_bt_topics = st_tags(
                label="List of topics",
                text="Type and press enter",
                value=["politics", "religion", "money", "crime"],
                suggestions=[],
                maxtags=30,
                key="bt_topics",
            )
            st_bt_threshold = _threshold_slider("ban_topics_threshold", 0.75)

        settings["BanTopics"] = {"topics": st_bt_topics, "threshold": st_bt_threshold}

    if "Bias" in st_enabled_scanners:
        with st.sidebar.expander("Bias", expanded=False):
            st_bias_threshold = _threshold_slider("bias_threshold", 0.75)

        settings["Bias"] = {"threshold": st_bias_threshold}

    if "Code" in st_enabled_scanners:
        with st.sidebar.expander("Code", expanded=False):
            st_cd_languages = st.multiselect(
                "Programming languages",
                options=["python", "java", "javascript", "go", "php", "ruby"],
                default=["python"],
            )
            st_cd_mode = st.selectbox("Mode", ["allowed", "denied"], index=0)

        settings["Code"] = {"languages": st_cd_languages, "mode": st_cd_mode}

    if "MaliciousURLs" in st_enabled_scanners:
        with st.sidebar.expander("Malicious URLs", expanded=False):
            st_murls_threshold = _threshold_slider("murls_threshold", 0.75)

        settings["MaliciousURLs"] = {"threshold": st_murls_threshold}

    if "NoRefusal" in st_enabled_scanners:
        with st.sidebar.expander("No refusal", expanded=False):
            st_no_ref_threshold = _threshold_slider("no_ref_threshold", 0.5)

        settings["NoRefusal"] = {"threshold": st_no_ref_threshold}

    if "Refutation" in st_enabled_scanners:
        with st.sidebar.expander("Refutation", expanded=False):
            st_refu_threshold = _threshold_slider("refu_threshold", 0.5)

        settings["Refutation"] = {"threshold": st_refu_threshold}

    if "Regex" in st_enabled_scanners:
        with st.sidebar.expander("Regex", expanded=False):
            # Blank lines are dropped: an empty pattern would match any text.
            st_regex_patterns = _non_empty_lines(
                st.text_area(
                    "Enter patterns to ban (one per line)",
                    value="Bearer [A-Za-z0-9-._~+/]+",
                    height=200,
                )
            )
            st_regex_type = st.selectbox(
                "Match type",
                ["good", "bad"],
                index=1,
                help="good: allow only good patterns, bad: ban bad patterns",
            )

        settings["Regex"] = {"patterns": st_regex_patterns, "type": st_regex_type}

    if "Relevance" in st_enabled_scanners:
        with st.sidebar.expander("Relevance", expanded=False):
            st_rele_threshold = _threshold_slider(
                "rele_threshold",
                0.5,
                min_value=-1.0,
                help_text="The minimum cosine similarity (-1 to 1) between the prompt and output for the output to be considered relevant.",
            )

        settings["Relevance"] = {"threshold": st_rele_threshold}

    if "Sensitive" in st_enabled_scanners:
        with st.sidebar.expander("Sensitive", expanded=False):
            st_sens_entity_types = st_tags(
                label="Sensitive entities",
                text="Type and press enter",
                value=default_entity_types,
                suggestions=default_entity_types
                + ["DATE_TIME", "NRP", "LOCATION", "MEDICAL_LICENSE", "US_PASSPORT"],
                maxtags=30,
                key="sensitive_entity_types",
            )
            st.caption(
                "Check all supported entities: https://microsoft.github.io/presidio/supported_entities/#list-of-supported-entities"
            )

        settings["Sensitive"] = {"entity_types": st_sens_entity_types}

    if "Sentiment" in st_enabled_scanners:
        with st.sidebar.expander("Sentiment", expanded=False):
            st_sent_threshold = _threshold_slider(
                "sentiment_threshold",
                -0.1,
                min_value=-1.0,
                step=0.1,
                help_text="Negative values are negative sentiment, positive values are positive sentiment",
            )

        settings["Sentiment"] = {"threshold": st_sent_threshold}

    if "Toxicity" in st_enabled_scanners:
        with st.sidebar.expander("Toxicity", expanded=False):
            st_tox_threshold = _threshold_slider(
                "toxicity_threshold",
                0.0,
                min_value=-1.0,
                help_text="A negative value (closer to 0 as the label output) indicates toxicity in the text, while a positive logit (closer to 1 as the label output) suggests non-toxicity.",
            )

        settings["Toxicity"] = {"threshold": st_tox_threshold}

    return st_enabled_scanners, settings
def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
    """Build the output scanner called ``scanner_name`` from its options.

    Args:
        scanner_name: Name of the scanner to construct, e.g. ``"Bias"``.
        vault: Handed to the Deanonymize scanner; the other scanners ignore it.
        settings: Options for this one scanner; which keys are read depends
            on the scanner (``"threshold"``, ``"topics"``, ``"patterns"``, ...).

    Returns:
        A newly constructed scanner instance.

    Raises:
        ValueError: If ``scanner_name`` is not one of the handled names.
    """
    logger.debug(f"Initializing {scanner_name} scanner")

    # Scanners whose only option is a threshold share one construction path.
    threshold_only = {
        "Bias": Bias,
        "MaliciousURLs": MaliciousURLs,
        "NoRefusal": NoRefusal,
        "Refutation": Refutation,
        "Relevance": Relevance,
        "Sentiment": Sentiment,
        "Toxicity": Toxicity,
    }
    if scanner_name in threshold_only:
        scanner_cls = threshold_only[scanner_name]
        return scanner_cls(threshold=settings["threshold"])

    if scanner_name == "Deanonymize":
        return Deanonymize(vault=vault)

    if scanner_name == "BanSubstrings":
        return BanSubstrings(
            substrings=settings["substrings"],
            match_type=settings["match_type"],
            case_sensitive=settings["case_sensitive"],
        )

    if scanner_name == "BanTopics":
        return BanTopics(topics=settings["topics"], threshold=settings["threshold"])

    if scanner_name == "Sensitive":
        return Sensitive(entity_types=settings["entity_types"])

    if scanner_name == "Code":
        # The language list feeds exactly one of the two arguments, chosen by mode.
        selected_mode = settings["mode"]
        return Code(
            allowed=settings["languages"] if selected_mode == "allowed" else None,
            denied=settings["languages"] if selected_mode == "denied" else None,
        )

    if scanner_name == "Regex":
        # Same idea as Code: the patterns are either the "good" or the "bad" set.
        pattern_kind = settings["type"]
        return Regex(
            good_patterns=settings["patterns"] if pattern_kind == "good" else None,
            bad_patterns=settings["patterns"] if pattern_kind == "bad" else None,
        )

    raise ValueError("Unknown scanner name")
def scan(
    vault: Vault, enabled_scanners: List[str], settings: Dict, prompt: str, text: str
) -> Tuple[str, Dict[str, bool], Dict[str, float]]:
    """Run every enabled scanner over the model output, one after another.

    Each scanner receives the output as left by the previous scanner, so
    sanitisation steps accumulate. Progress is shown in a Streamlit status box.

    Args:
        vault: Forwarded to ``get_scanner`` for scanners that need it.
        enabled_scanners: Scanner names to run, in this order.
        settings: Mapping of scanner name to that scanner's options dict.
        prompt: The prompt that produced ``text``; passed to every scanner.
        text: The model output to scan.

    Returns:
        A ``(sanitized_output, results_valid, results_score)`` triple: the
        output after all scanners ran, plus per-scanner validity flags and
        risk scores keyed by scanner name.
    """
    sanitized_output = text
    results_valid: Dict[str, bool] = {}
    results_score: Dict[str, float] = {}

    with st.status("Scanning output...", expanded=True) as status:
        for scanner_name in enabled_scanners:
            st.write(f"{scanner_name} scanner...")
            # Scanners without an entry in ``settings`` get an empty options dict.
            scanner = get_scanner(scanner_name, vault, settings.get(scanner_name, {}))
            sanitized_output, is_valid, risk_score = scanner.scan(prompt, sanitized_output)
            results_valid[scanner_name] = is_valid
            results_score[scanner_name] = risk_score
        status.update(label="Scanning complete", state="complete", expanded=False)

    return sanitized_output, results_valid, results_score