Spaces:
Running on A10G

rynmurdock committed
Commit • 86d2837
1 Parent(s): 9c7e8e1
limit to 10 rows from 1 user for diversity.
app.py
CHANGED
@@ -1,6 +1,6 @@
 
 
-
+# TODO unify/merge origin and this
 # TODO save & restart from (if it exists) dataframe parquet
 import torch
 
@@ -37,12 +37,9 @@ torch.set_grad_enabled(False)
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
 
-prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate'])
+prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'from_user_id'])
 
 import spaces
-prompt_list = [p for p in list(set(
-    pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
-
 start_time = time.time()
 
 ####################### Setup Model
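The in-memory queue is a plain pandas DataFrame, and this commit adds a from_user_id column so each generated row can be traced back to the user whose embedding spawned it. A minimal standalone sketch of how that column gets used (column names come from the diff; the path and user id are made up):

import pandas as pd

# schema from the diff; 'from_user_id' is the newly added column
prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating',
                                 'latest_user_to_rate', 'from_user_id'])

# illustrative row: one generated video attributed to the user who spawned it
row = pd.DataFrame({'paths': ['./example.mp4'], 'embeddings': [None], 'ips': [None],
                    'user:rating': [{' ': ' '}], 'latest_user_to_rate': [None],
                    'from_user_id': ['user-a']})
prevs_df = pd.concat((prevs_df, row))

# later code filters on this column to count rows spawned by a given user
from_user = prevs_df[prevs_df['from_user_id'] == 'user-a']
print(len(from_user))  # -> 1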
@@ -55,13 +52,13 @@ from transformers import CLIPVisionModelWithProjection
 import uuid
 import av
 
-def write_video(file_name, images, fps=17):
+def write_video_av(file_name, images, fps=17):
     print('Saving')
     container = av.open(file_name, mode="w")
 
     stream = container.add_stream("h264", rate=fps)
     # stream.options = {'preset': 'faster'}
-    stream.thread_count =
+    stream.thread_count = -1
     stream.width = 512
     stream.height = 512
     stream.pix_fmt = "yuv420p"
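The renamed write_video_av keeps the PyAV/H.264 path; only the stream setup appears in this hunk, so the following is a hedged sketch of what a complete writer of this shape typically looks like (the encode/mux loop is assumed here, it is not shown in the diff):

import av
import numpy as np

def write_video_av_sketch(file_name, images, fps=17):
    # open an output container and configure an H.264 stream, as in the diff
    container = av.open(file_name, mode="w")
    stream = container.add_stream("h264", rate=fps)
    stream.width = 512
    stream.height = 512
    stream.pix_fmt = "yuv420p"
    # assumed frame loop: encode each RGB frame and mux the resulting packets
    for img in images:
        frame = av.VideoFrame.from_ndarray(np.asarray(img, dtype=np.uint8), format="rgb24")
        for packet in stream.encode(frame):
            container.mux(packet)
    # flush any buffered packets before closing
    for packet in stream.encode():
        container.mux(packet)
    container.close()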
@@ -79,8 +76,16 @@ def write_video(file_name, images, fps=17):
     container.close()
     print('Saved')
 
+def write_video(file_name, images, fps=15):
+    writer = imageio.get_writer(file_name, fps=fps)
+
+    for im in images:
+        writer.append_data(np.array(im))
+    writer.close()
+
 
-image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=dtype
+image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=dtype,
+                                                              device_map='cpu')
 #vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=dtype)
 
 # vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
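The new write_video goes through imageio instead of PyAV for the common case. A minimal usage sketch, assuming imageio with its ffmpeg plugin and numpy are installed (the random frames are only illustrative stand-ins for pipeline output):

import numpy as np
import imageio

# sixteen fake 512x512 RGB frames standing in for generated video frames
frames = [np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8) for _ in range(16)]

writer = imageio.get_writer('preview.mp4', fps=15)
for im in frames:
    writer.append_data(np.array(im))
writer.close()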
@@ -91,8 +96,9 @@ image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter",
 #text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype)
 
 
-unet = UNet2DConditionModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='unet').to(dtype)
-text_encoder = CLIPTextModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='text_encoder'
+unet = UNet2DConditionModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='unet',).to(dtype).to('cpu')
+text_encoder = CLIPTextModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='text_encoder',
+                                             device_map='cpu').to(dtype)
 
 adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
 pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype, unet=unet, text_encoder=text_encoder)
@@ -101,6 +107,7 @@ pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_
 pipe.set_adapters(["lcm-lora"], [.9])
 pipe.fuse_lora()
 
+
 #pipe = AnimateDiffPipeline.from_pretrained('emilianJR/epiCRealism', torch_dtype=dtype, image_encoder=image_encoder)
 #pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
 #repo = "ByteDance/AnimateDiff-Lightning"
@@ -116,8 +123,7 @@ pipe.unet.fuse_qkv_projections()
 pipe.to(device=DEVICE)
 #pipe.unet = torch.compile(pipe.unet)
 #pipe.vae = torch.compile(pipe.vae)
-
-
+# TODO cannot compile on Spaces or we time out; don't run leave_imb stuff either
 #im_embs = torch.zeros(1, 1, 1, 1280, device=DEVICE, dtype=dtype)
 #output = pipe(prompt='a person', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[im_embs], num_inference_steps=STEPS)
 #leave_im_emb, _ = pipe.encode_image(
@@ -126,13 +132,13 @@ pipe.to(device=DEVICE)
 #assert len(output.frames[0]) == 16
 #leave_im_emb.detach().to('cpu')
 
-@spaces.GPU(duration=
+@spaces.GPU(duration=10)
 def generate_gpu(in_im_embs):
     print('start gen')
     in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
     #im_embs = torch.cat((torch.zeros(1, 1280, device=DEVICE, dtype=dtype), in_im_embs), 0)
 
-    output = pipe(prompt='
+    output = pipe(prompt='', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
     print('image is made')
     im_emb, _ = pipe.encode_image(
         output.frames[0][len(output.frames[0])//2], 'cuda', 1, output_hidden_state
@@ -163,10 +169,6 @@ def generate(in_im_embs):
 
 #######################
 
-
-# TODO only generate ~5 new images ahead from a specific user embedding. Do this by tracking a column of who's embedding it was and
-# taking the intersection for unrated by that user and from that users' embedding. Then we keep styles less consistent for better variety.
-
 def get_user_emb(embs, ys):
     # handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike'
     if len(list(set(ys))) <= 1:
@@ -245,7 +247,17 @@ def background_next_image():
     for uid in user_id_list:
         rated_rows = prevs_df[[i[1]['user:rating'].get(uid, None) is not None for i in prevs_df.iterrows()]]
         not_rated_rows = prevs_df[[i[1]['user:rating'].get(uid, None) is None for i in prevs_df.iterrows()]]
-
+
+        # we need to intersect not_rated_rows from this user's embed > 7. Just add a new column on which user_id spawned the
+        # media.
+
+        from_user = prevs_df[[i[1]['from_user_id'] == uid for i in prevs_df.iterrows()]]
+        if len(from_user) >= 10:
+            oldest = from_user.iloc[-1]['paths']
+            print(f'User has {len(from_user)} rows. Popping oldest: {oldest}')
+            prevs_df = prevs_df[prevs_df['paths'] != oldest]
+
+        if len(rated_rows) < 4:
             print(f'latest user {uid} has < 4 rows') # or > 7 unrated rows')
             continue
 
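This hunk is the change the commit message refers to: each user keeps at most 10 generated rows, and once that cap is reached the row the diff treats as oldest (from_user.iloc[-1]) is dropped before a new one is queued. A self-contained sketch of that pruning step on a toy DataFrame (whether iloc[-1] is really the oldest row depends on how rows are appended):

import pandas as pd

# toy queue: 10 rows already spawned from user 'u1', plus one from 'u2'
prevs_df = pd.DataFrame({'paths': [f'./{i}.mp4' for i in range(11)],
                         'from_user_id': ['u1'] * 10 + ['u2']})

uid = 'u1'
from_user = prevs_df[prevs_df['from_user_id'] == uid]
if len(from_user) >= 10:
    oldest = from_user.iloc[-1]['paths']            # row the diff pops
    prevs_df = prevs_df[prevs_df['paths'] != oldest]

print(len(prevs_df[prevs_df['from_user_id'] == uid]))  # -> 9, room for a new row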
@@ -260,6 +272,7 @@ def background_next_image():
         tmp_df['paths'] = [img]
         tmp_df['embeddings'] = [embs]
         tmp_df['user:rating'] = [{' ': ' '}]
+        tmp_df['from_user_id'] = [uid]
         prevs_df = pd.concat((prevs_df, tmp_df))
         # we can free up storage by deleting the image
         if len(prevs_df) > 50:
@@ -345,7 +358,9 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
         choice = 0
 
     row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
-
+
+
+    if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
         prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
         prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
     img, calibrate_prompts = next_image(calibrate_prompts, user_id)
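The added length check guards the rating write: if the rated video has already been pruned from prevs_df (for example by the new per-user cap), the boolean mask matches nothing and indexing it would fail, so the rating is simply skipped. A small sketch of the failure mode the guard avoids, on toy data:

import pandas as pd

prevs_df = pd.DataFrame({'paths': ['./kept.mp4'], 'user:rating': [{' ': ' '}]})
img = './already_pruned.mp4'   # video no longer present in the queue

row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
    prevs_df.loc[row_mask, 'user:rating'][0]['some-user'] = 1
else:
    print('row was already evicted; rating dropped')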
@@ -411,6 +426,7 @@ Explore the latent space without text prompts based on your preferences. Learn m
     ''', elem_id="description")
     user_id = gr.State()
     print('USER_ID: ',user_id)
+    # calibration videos -- this is a misnomer now :D
     calibrate_prompts = gr.State([
         './first.mp4',
         './second.mp4',
@@ -429,7 +445,7 @@ Explore the latent space without text prompts based on your preferences. Learn m
         interactive=False,
         height=512,
         width=512,
-        include_audio=False,
+        #include_audio=False,
         elem_id="video_output"
     )
     img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
@@ -471,12 +487,12 @@ log = logging.getLogger('log_here')
 log.setLevel(logging.ERROR)
 
 scheduler = BackgroundScheduler()
-scheduler.add_job(func=background_next_image, trigger="interval", seconds
+scheduler.add_job(func=background_next_image, trigger="interval", seconds=.1)
 scheduler.start()
 
 def encode_space(x):
     im_emb, _ = pipe.encode_image(
-        image,
+        image, DEVICE, 1, output_hidden_state
     )
     return im_emb.detach().to('cpu').to(torch.float32)
 
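The scheduler now polls background_next_image every 0.1 seconds via APScheduler. A standalone sketch of that wiring (the job body here is a stand-in; the real one generates and queues media for active users):

import time
from apscheduler.schedulers.background import BackgroundScheduler

def background_next_image():
    print('tick: would generate and queue the next video here')

scheduler = BackgroundScheduler()
scheduler.add_job(func=background_next_image, trigger="interval", seconds=.1)
scheduler.start()

time.sleep(1)        # let the interval job fire a few times
scheduler.shutdown()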