fffiloni commited on
Commit
558b476
·
verified ·
1 Parent(s): cf577fc

include safety in app

Browse files
Files changed (1) hide show
  1. app.py +35 -12
app.py CHANGED
@@ -5,19 +5,43 @@ import time
5
  import os
6
  import re
7
  from gradio_client import Client
 
 
8
 
9
- is_shared_ui = True if "fffiloni/consistent-character" in os.environ['SPACE_ID'] else False
10
- def safety_check(user_prompt, token):
 
 
 
 
 
 
 
 
 
11
 
12
- client = Client("fffiloni/safety-checker-bot", hf_token=token)
13
- response = client.predict(
14
- source_space="consistent-character space",
15
- user_prompt=user_prompt,
16
- api_name="/infer"
17
- )
18
- print(response)
19
 
20
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  from utils.gradio_helpers import parse_outputs, process_outputs
23
 
@@ -33,9 +57,8 @@ def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):
33
 
34
  try:
35
  if is_shared_ui:
36
- hf_token = os.environ.get("HF_TOKEN")
37
 
38
- is_safe = safety_check(args[0], hf_token)
39
  print(is_safe)
40
 
41
  match = re.search(r'\bYes\b', is_safe)
 
5
  import os
6
  import re
7
  from gradio_client import Client
8
+ import torch
9
+ from transformers import pipeline
10
 
11
+ pipe_safety = pipeline("text-generation", model="HuggingFaceH4/zephyr-7b-beta", torch_dtype=torch.bfloat16, device_map="auto")
12
+
13
+ agent_maker_sys = os.environ.get("SAFETY_PROMPT")
14
+
15
+ instruction = f"""
16
+ <|system|>
17
+ {agent_maker_sys}</s>
18
+ <|user|>
19
+ """
20
+
21
+ def safety_check(user_prompt):
22
 
23
+ prompt = f"{instruction.strip()}\n'{user_prompt}'</s>"
24
+ print(f"""
25
+
26
+ USER PROMPT: {user_prompt}
27
+ """)
28
+
29
+ outputs = pipe_safety(prompt, max_new_tokens=256, do_sample=True, temperature=0.3, top_k=50, top_p=0.95)
30
 
31
+
32
+ pattern = r'\<\|system\|\>(.*?)\<\|assistant\|\>'
33
+ cleaned_text = re.sub(pattern, '', outputs[0]["generated_text"], flags=re.DOTALL)
34
+
35
+ print(f"""
36
+
37
+ SAFETY COUNCIL: {cleaned_text}
38
+
39
+ """)
40
+
41
+ return cleaned_text.lstrip("\n")
42
+
43
+
44
+ is_shared_ui = True if "fffiloni/consistent-character" in os.environ['SPACE_ID'] else False
45
 
46
  from utils.gradio_helpers import parse_outputs, process_outputs
47
 
 
57
 
58
  try:
59
  if is_shared_ui:
 
60
 
61
+ is_safe = safety_check(args[0])
62
  print(is_safe)
63
 
64
  match = re.search(r'\bYes\b', is_safe)