juanpablomesa
commited on
Commit
·
0f7ecda
1
Parent(s):
a0a72c0
Added normalization using torch
Browse files- handler.py +4 -5
handler.py
CHANGED
@@ -118,14 +118,13 @@ class EndpointHandler:
|
|
118 |
self.logger.info("Squeezing tensor")
|
119 |
batch_emb = frame_embedding.squeeze(0)
|
120 |
|
|
|
|
|
|
|
|
|
121 |
self.logger.info("Converting into numpy array")
|
122 |
batch_emb = batch_emb.cpu().detach().numpy()
|
123 |
|
124 |
-
# NORMALIZE
|
125 |
-
# self.logger.info("Normalizing numpy array")
|
126 |
-
# batch_emb = batch_emb.T / np.linalg.norm(batch_emb, axis=1)
|
127 |
-
# transpose back to (21, 512)
|
128 |
-
|
129 |
self.logger.info("Converting to list")
|
130 |
batch_emb = batch_emb.tolist()
|
131 |
|
|
|
118 |
self.logger.info("Squeezing tensor")
|
119 |
batch_emb = frame_embedding.squeeze(0)
|
120 |
|
121 |
+
# Normalize the embeddings
|
122 |
+
self.logger.info("Normalizing embeddings")
|
123 |
+
batch_emb = torch.nn.functional.normalize(batch_emb, p=2, dim=1)
|
124 |
+
|
125 |
self.logger.info("Converting into numpy array")
|
126 |
batch_emb = batch_emb.cpu().detach().numpy()
|
127 |
|
|
|
|
|
|
|
|
|
|
|
128 |
self.logger.info("Converting to list")
|
129 |
batch_emb = batch_emb.tolist()
|
130 |
|