RishabhBhardwaj commited on
Commit
0a9972d
1 Parent(s): 8be32c4
Files changed (1) hide show
  1. app.py +13 -31
app.py CHANGED
@@ -3,7 +3,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import requests
4
  from PIL import Image
5
  from io import BytesIO
6
- from datasets import load_dataset
7
 
8
  # Define the template
9
  TEMPLATE = """[INST] Evaluate if the given text is harmful i.e. either it asks for unsafe information or it contains unsafe information.
@@ -29,12 +28,6 @@ def load_image_from_url(url):
29
  img = Image.open(BytesIO(response.content))
30
  return img
31
 
32
- # Load dataset
33
- @st.cache_data
34
- def load_example_dataset():
35
- ds = load_dataset("walledai/XSTest")
36
- return ds['train']['prompt'][:10] # Get first 10 examples
37
-
38
  # Evaluation function
39
  def evaluate_text(user_input):
40
  if user_input:
@@ -64,19 +57,8 @@ st.title("Text Safety Evaluator")
64
  if 'model_and_tokenizer' not in st.session_state:
65
  st.session_state.model_and_tokenizer = load_model()
66
 
67
- # Load example dataset
68
- example_prompts = load_example_dataset()
69
-
70
- # Display example prompts
71
- st.subheader("Example Inputs:")
72
- for i, prompt in enumerate(example_prompts):
73
- if st.button(f"Example {i+1}", key=f"example_{i}"):
74
- st.session_state.user_input = prompt
75
-
76
  # User input
77
- user_input = st.text_area("Enter the text you want to evaluate:",
78
- height=100,
79
- value=st.session_state.get('user_input', ''))
80
 
81
  # Create an empty container for the result
82
  result_container = st.empty()
@@ -90,17 +72,17 @@ if st.button("Evaluate"):
90
  result_container.warning("Please enter some text to evaluate.")
91
 
92
  # Add logo at the bottom center (only once)
93
- if 'logo_displayed' not in st.session_state:
94
- col1, col2, col3 = st.columns([1,2,1])
95
- with col2:
96
- logo_url = "https://github.com/walledai/walledeval/assets/32847115/d8b1d14f-7071-448b-8997-2eeba4c2c8f6"
97
- logo = load_image_from_url(logo_url)
98
- st.image(logo, use_column_width=True, width=500) # Adjust the width as needed
99
- st.session_state.logo_displayed = True
100
 
101
  # Add information about Walled Guard Advanced (only once)
102
- if 'info_displayed' not in st.session_state:
103
- col1, col2, col3 = st.columns([1,2,1])
104
- with col2:
105
- st.info("For a more performant version, check out Walled Guard Advanced. Connect with us at [email protected] for more information.")
106
- st.session_state.info_displayed = True
 
3
  import requests
4
  from PIL import Image
5
  from io import BytesIO
 
6
 
7
  # Define the template
8
  TEMPLATE = """[INST] Evaluate if the given text is harmful i.e. either it asks for unsafe information or it contains unsafe information.
 
28
  img = Image.open(BytesIO(response.content))
29
  return img
30
 
 
 
 
 
 
 
31
  # Evaluation function
32
  def evaluate_text(user_input):
33
  if user_input:
 
57
  if 'model_and_tokenizer' not in st.session_state:
58
  st.session_state.model_and_tokenizer = load_model()
59
 
 
 
 
 
 
 
 
 
 
60
  # User input
61
+ user_input = st.text_area("Enter the text you want to evaluate:", height=100)
 
 
62
 
63
  # Create an empty container for the result
64
  result_container = st.empty()
 
72
  result_container.warning("Please enter some text to evaluate.")
73
 
74
  # Add logo at the bottom center (only once)
75
+ #if 'logo_displayed' not in st.session_state:
76
+ col1, col2, col3 = st.columns([1,2,1])
77
+ with col2:
78
+ logo_url = "https://github.com/walledai/walledeval/assets/32847115/d8b1d14f-7071-448b-8997-2eeba4c2c8f6"
79
+ logo = load_image_from_url(logo_url)
80
+ st.image(logo, use_column_width=True, width=500) # Adjust the width as needed
81
+ #st.session_state.logo_displayed = True
82
 
83
  # Add information about Walled Guard Advanced (only once)
84
+ #if 'info_displayed' not in st.session_state:
85
+ col1, col2, col3 = st.columns([1,2,1])
86
+ with col2:
87
+ st.info("For a more performant version, check out Walled Guard Advanced. Connect with us at [email protected] for more information.")
88
+ #st.session_state.info_displayed = True