import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import requests
from PIL import Image
from io import BytesIO
# Define the template
TEMPLATE = """[INST] Evaluate if the given text is harmful i.e. either it asks for unsafe information or it contains unsafe information.
<START TEXT>
{prompt}
<END TEXT>
Answer: [/INST]
"""
# Load the model and tokenizer
@st.cache_resource
def load_model():
    model_name = "walledai/walledguard-c"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return tokenizer, model
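# Note: the first call downloads the walledai/walledguard-c weights from the
# Hugging Face Hub; st.cache_resource keeps the loaded tokenizer and model
# in memory across Streamlit reruns.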
# Function to load image from URL
@st.cache_data()
def load_image_from_url(url):
    response = requests.get(url)
    img = Image.open(BytesIO(response.content))
    return img
# Evaluation function
def evaluate_text(user_input):
    if user_input:
        # Get model and tokenizer from session state
        tokenizer, model = st.session_state.model_and_tokenizer
        # Prepare input
        input_ids = tokenizer.encode(TEMPLATE.format(prompt=user_input), return_tensors="pt")
        # Generate output
        output = model.generate(input_ids=input_ids, max_new_tokens=20, pad_token_id=0)
        # Decode only the newly generated tokens (skip the prompt)
        prompt_len = input_ids.shape[-1]
        output_decoded = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
        # Determine prediction
        prediction = 'unsafe' if 'unsafe' in output_decoded.lower() else 'safe'
        return prediction
    return None
# Streamlit app
st.title("Text Safety Evaluator")
# Load model and tokenizer once and store in session state
if 'model_and_tokenizer' not in st.session_state:
    st.session_state.model_and_tokenizer = load_model()
# User input
user_input = st.text_area("Enter the text you want to evaluate:", height=100)
# Create an empty container for the result
result_container = st.empty()
if st.button("Evaluate"):
    prediction = evaluate_text(user_input)
    if prediction:
        # st.empty() holds a single element, so group the subheader and the
        # result text in a container so both stay visible
        with result_container.container():
            st.subheader("Evaluation Result:")
            st.write(f"The text is evaluated as: **{prediction.upper()}**")
    else:
        result_container.warning("Please enter some text to evaluate.")
# Add logo at the bottom center (only once)
#if 'logo_displayed' not in st.session_state:
col1, col2, col3 = st.columns([1,2,1])
with col2:
    logo_url = "https://github.com/walledai/walledeval/assets/32847115/d8b1d14f-7071-448b-8997-2eeba4c2c8f6"
    logo = load_image_from_url(logo_url)
    st.image(logo, use_column_width=True)  # use_column_width overrides any fixed width
#st.session_state.logo_displayed = True
# Add information about Walled Guard Advanced (only once)
#if 'info_displayed' not in st.session_state:
col1, col2, col3 = st.columns([1,2,1])
with col2:
    st.info("For a more performant version, check out Walled Guard Advanced. Connect with us at [email protected] for more information.")
#st.session_state.info_displayed = True
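# To run this app locally (assuming the file is saved as app.py):
#   pip install streamlit transformers torch requests pillow
#   streamlit run app.py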