zzc0208 committed on
Commit
161e8c8
·
verified ·
1 Parent(s): 8f56c5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -33,7 +33,7 @@ from PIL import Image
33
  from torchvision.utils import make_grid, save_image
34
  from transformers import AutoModelForCausalLM, AutoTokenizer
35
 
36
- from app import safety_check
37
  from app.sana_pipeline import SanaPipeline
38
 
39
  MAX_SEED = np.iinfo(np.int32).max
@@ -205,12 +205,12 @@ if torch.cuda.is_available():
205
  pipe.register_progress_bar(gr.Progress())
206
 
207
  # safety checker
208
- safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
209
- safety_checker_model = AutoModelForCausalLM.from_pretrained(
210
- args.shield_model_path,
211
- device_map="auto",
212
- torch_dtype=torch.bfloat16,
213
- ).to(device)
214
 
215
 
216
  def save_image_sana(img, seed="", save_img=False):
@@ -254,8 +254,8 @@ def generate(
254
  seed = int(randomize_seed_fn(seed, randomize_seed))
255
  generator = torch.Generator(device=device).manual_seed(seed)
256
  print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
257
- if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
258
- prompt = "A red heart."
259
 
260
  print(prompt)
261
 
 
33
  from torchvision.utils import make_grid, save_image
34
  from transformers import AutoModelForCausalLM, AutoTokenizer
35
 
36
+ #from app import safety_check
37
  from app.sana_pipeline import SanaPipeline
38
 
39
  MAX_SEED = np.iinfo(np.int32).max
 
205
  pipe.register_progress_bar(gr.Progress())
206
 
207
  # safety checker
208
+ #safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
209
+ #safety_checker_model = AutoModelForCausalLM.from_pretrained(
210
+ # args.shield_model_path,
211
+ # device_map="auto",
212
+ # torch_dtype=torch.bfloat16,
213
+ #).to(device)
214
 
215
 
216
  def save_image_sana(img, seed="", save_img=False):
 
254
  seed = int(randomize_seed_fn(seed, randomize_seed))
255
  generator = torch.Generator(device=device).manual_seed(seed)
256
  print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
257
+ #if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
258
+ # prompt = "A red heart."
259
 
260
  print(prompt)
261