Spaces:
Runtime error
Runtime error
hysts
commited on
Commit
·
552d2be
1
Parent(s):
2f2bde2
Update logger
Browse files
model.py
CHANGED
@@ -215,8 +215,7 @@ class Model:
|
|
215 |
model, args = InferenceModel.from_pretrained(self.args, 'coglm')
|
216 |
|
217 |
elapsed = time.perf_counter() - start
|
218 |
-
logger.info(f'
|
219 |
-
logger.info('--- done ---')
|
220 |
return model, args
|
221 |
|
222 |
def load_strategy(self) -> CoglmStrategy:
|
@@ -230,8 +229,7 @@ class Model:
|
|
230 |
top_k_cluster=self.args.temp_cluster_gen)
|
231 |
|
232 |
elapsed = time.perf_counter() - start
|
233 |
-
logger.info(f'
|
234 |
-
logger.info('--- done ---')
|
235 |
return strategy
|
236 |
|
237 |
def load_srg(self) -> SRGroup:
|
@@ -241,8 +239,7 @@ class Model:
|
|
241 |
srg = None if self.args.only_first_stage else SRGroup(self.args)
|
242 |
|
243 |
elapsed = time.perf_counter() - start
|
244 |
-
logger.info(f'
|
245 |
-
logger.info('--- done ---')
|
246 |
return srg
|
247 |
|
248 |
def update_style(self, style: str) -> None:
|
@@ -267,8 +264,7 @@ class Model:
|
|
267 |
self.srg.itersr.strategy.topk = self.args.topk_itersr
|
268 |
|
269 |
elapsed = time.perf_counter() - start
|
270 |
-
logger.info(f'
|
271 |
-
logger.info('--- done ---')
|
272 |
|
273 |
def run(self, text: str, style: str, seed: int, only_first_stage: bool,
|
274 |
num: int) -> list[np.ndarray] | None:
|
@@ -306,8 +302,7 @@ class Model:
|
|
306 |
seq = torch.tensor(seq + [-1] * 400, device=self.device)
|
307 |
|
308 |
elapsed = time.perf_counter() - start
|
309 |
-
logger.info(f'
|
310 |
-
logger.info('--- done ---')
|
311 |
return seq, txt_len
|
312 |
|
313 |
@torch.inference_mode()
|
@@ -345,8 +340,7 @@ class Model:
|
|
345 |
logger.debug(f'{output_tokens.shape=}')
|
346 |
|
347 |
elapsed = time.perf_counter() - start
|
348 |
-
logger.info(f'
|
349 |
-
logger.info('--- done ---')
|
350 |
return output_tokens
|
351 |
|
352 |
@staticmethod
|
@@ -380,8 +374,7 @@ class Model:
|
|
380 |
res.append(decoded_img) # only the last image (target)
|
381 |
|
382 |
elapsed = time.perf_counter() - start
|
383 |
-
logger.info(f'
|
384 |
-
logger.info('--- done ---')
|
385 |
return res
|
386 |
|
387 |
|
|
|
215 |
model, args = InferenceModel.from_pretrained(self.args, 'coglm')
|
216 |
|
217 |
elapsed = time.perf_counter() - start
|
218 |
+
logger.info(f'--- done ({elapsed=:.3f} ---')
|
|
|
219 |
return model, args
|
220 |
|
221 |
def load_strategy(self) -> CoglmStrategy:
|
|
|
229 |
top_k_cluster=self.args.temp_cluster_gen)
|
230 |
|
231 |
elapsed = time.perf_counter() - start
|
232 |
+
logger.info(f'--- done ({elapsed=:.3f} ---')
|
|
|
233 |
return strategy
|
234 |
|
235 |
def load_srg(self) -> SRGroup:
|
|
|
239 |
srg = None if self.args.only_first_stage else SRGroup(self.args)
|
240 |
|
241 |
elapsed = time.perf_counter() - start
|
242 |
+
logger.info(f'--- done ({elapsed=:.3f} ---')
|
|
|
243 |
return srg
|
244 |
|
245 |
def update_style(self, style: str) -> None:
|
|
|
264 |
self.srg.itersr.strategy.topk = self.args.topk_itersr
|
265 |
|
266 |
elapsed = time.perf_counter() - start
|
267 |
+
logger.info(f'--- done ({elapsed=:.3f} ---')
|
|
|
268 |
|
269 |
def run(self, text: str, style: str, seed: int, only_first_stage: bool,
|
270 |
num: int) -> list[np.ndarray] | None:
|
|
|
302 |
seq = torch.tensor(seq + [-1] * 400, device=self.device)
|
303 |
|
304 |
elapsed = time.perf_counter() - start
|
305 |
+
logger.info(f'--- done ({elapsed=:.3f} ---')
|
|
|
306 |
return seq, txt_len
|
307 |
|
308 |
@torch.inference_mode()
|
|
|
340 |
logger.debug(f'{output_tokens.shape=}')
|
341 |
|
342 |
elapsed = time.perf_counter() - start
|
343 |
+
logger.info(f'--- done ({elapsed=:.3f} ---')
|
|
|
344 |
return output_tokens
|
345 |
|
346 |
@staticmethod
|
|
|
374 |
res.append(decoded_img) # only the last image (target)
|
375 |
|
376 |
elapsed = time.perf_counter() - start
|
377 |
+
logger.info(f'--- done ({elapsed=:.3f} ---')
|
|
|
378 |
return res
|
379 |
|
380 |
|