Spaces:
Runtime error
Runtime error
Hugo Flores
commited on
Commit
·
326b5bb
1
Parent(s):
04c5b94
fix: sample prefix suffix
Browse files- scripts/exp/train.py +4 -3
scripts/exp/train.py
CHANGED
|
@@ -216,6 +216,7 @@ def accuracy(
|
|
| 216 |
return accuracy
|
| 217 |
|
| 218 |
def sample_prefix_suffix_amt(
|
|
|
|
| 219 |
n_batch,
|
| 220 |
prefix_amt,
|
| 221 |
suffix_amt,
|
|
@@ -362,7 +363,7 @@ def train(
|
|
| 362 |
n_batch = z.shape[0]
|
| 363 |
r = rng.draw(n_batch)[:, 0].to(accel.device)
|
| 364 |
|
| 365 |
-
n_prefix, n_suffix = sample_prefix_suffix_amt(
|
| 366 |
n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
|
| 367 |
prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
|
| 368 |
rng=rng
|
|
@@ -448,7 +449,7 @@ def train(
|
|
| 448 |
n_batch = z.shape[0]
|
| 449 |
r = rng.draw(n_batch)[:, 0].to(accel.device)
|
| 450 |
|
| 451 |
-
n_prefix, n_suffix = sample_prefix_suffix_amt(
|
| 452 |
n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
|
| 453 |
prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
|
| 454 |
rng=rng
|
|
@@ -606,7 +607,7 @@ def train(
|
|
| 606 |
|
| 607 |
n_batch = z.shape[0]
|
| 608 |
|
| 609 |
-
n_prefix, n_suffix = sample_prefix_suffix_amt(
|
| 610 |
n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
|
| 611 |
prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
|
| 612 |
rng=rng
|
|
|
|
| 216 |
return accuracy
|
| 217 |
|
| 218 |
def sample_prefix_suffix_amt(
|
| 219 |
+
z,
|
| 220 |
n_batch,
|
| 221 |
prefix_amt,
|
| 222 |
suffix_amt,
|
|
|
|
| 363 |
n_batch = z.shape[0]
|
| 364 |
r = rng.draw(n_batch)[:, 0].to(accel.device)
|
| 365 |
|
| 366 |
+
n_prefix, n_suffix = sample_prefix_suffix_amt(z=z,
|
| 367 |
n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
|
| 368 |
prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
|
| 369 |
rng=rng
|
|
|
|
| 449 |
n_batch = z.shape[0]
|
| 450 |
r = rng.draw(n_batch)[:, 0].to(accel.device)
|
| 451 |
|
| 452 |
+
n_prefix, n_suffix = sample_prefix_suffix_amt(z=z,
|
| 453 |
n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
|
| 454 |
prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
|
| 455 |
rng=rng
|
|
|
|
| 607 |
|
| 608 |
n_batch = z.shape[0]
|
| 609 |
|
| 610 |
+
n_prefix, n_suffix = sample_prefix_suffix_amt(z=z,
|
| 611 |
n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
|
| 612 |
prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
|
| 613 |
rng=rng
|