Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -33,7 +33,7 @@ from PIL import Image
|
|
33 |
from torchvision.utils import make_grid, save_image
|
34 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
35 |
|
36 |
-
from app import safety_check
|
37 |
from app.sana_pipeline import SanaPipeline
|
38 |
|
39 |
MAX_SEED = np.iinfo(np.int32).max
|
@@ -205,12 +205,12 @@ if torch.cuda.is_available():
|
|
205 |
pipe.register_progress_bar(gr.Progress())
|
206 |
|
207 |
# safety checker
|
208 |
-
safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
|
209 |
-
safety_checker_model = AutoModelForCausalLM.from_pretrained(
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
).to(device)
|
214 |
|
215 |
|
216 |
def save_image_sana(img, seed="", save_img=False):
|
@@ -254,8 +254,8 @@ def generate(
|
|
254 |
seed = int(randomize_seed_fn(seed, randomize_seed))
|
255 |
generator = torch.Generator(device=device).manual_seed(seed)
|
256 |
print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
|
257 |
-
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
|
258 |
-
|
259 |
|
260 |
print(prompt)
|
261 |
|
|
|
33 |
from torchvision.utils import make_grid, save_image
|
34 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
35 |
|
36 |
+
#from app import safety_check
|
37 |
from app.sana_pipeline import SanaPipeline
|
38 |
|
39 |
MAX_SEED = np.iinfo(np.int32).max
|
|
|
205 |
pipe.register_progress_bar(gr.Progress())
|
206 |
|
207 |
# safety checker
|
208 |
+
#safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
|
209 |
+
#safety_checker_model = AutoModelForCausalLM.from_pretrained(
|
210 |
+
# args.shield_model_path,
|
211 |
+
# device_map="auto",
|
212 |
+
# torch_dtype=torch.bfloat16,
|
213 |
+
#).to(device)
|
214 |
|
215 |
|
216 |
def save_image_sana(img, seed="", save_img=False):
|
|
|
254 |
seed = int(randomize_seed_fn(seed, randomize_seed))
|
255 |
generator = torch.Generator(device=device).manual_seed(seed)
|
256 |
print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
|
257 |
+
#if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
|
258 |
+
# prompt = "A red heart."
|
259 |
|
260 |
print(prompt)
|
261 |
|