juanpablomesa commited on
Commit
0f7ecda
·
1 Parent(s): a0a72c0

Added normalization using torch

Browse files
Files changed (1) hide show
  1. handler.py +4 -5
handler.py CHANGED
@@ -118,14 +118,13 @@ class EndpointHandler:
118
  self.logger.info("Squeezing tensor")
119
  batch_emb = frame_embedding.squeeze(0)
120
 
 
 
 
 
121
  self.logger.info("Converting into numpy array")
122
  batch_emb = batch_emb.cpu().detach().numpy()
123
 
124
- # NORMALIZE
125
- # self.logger.info("Normalizing numpy array")
126
- # batch_emb = batch_emb.T / np.linalg.norm(batch_emb, axis=1)
127
- # transpose back to (21, 512)
128
-
129
  self.logger.info("Converting to list")
130
  batch_emb = batch_emb.tolist()
131
 
 
118
  self.logger.info("Squeezing tensor")
119
  batch_emb = frame_embedding.squeeze(0)
120
 
121
+ # Normalize the embeddings
122
+ self.logger.info("Normalizing embeddings")
123
+ batch_emb = torch.nn.functional.normalize(batch_emb, p=2, dim=1)
124
+
125
  self.logger.info("Converting into numpy array")
126
  batch_emb = batch_emb.cpu().detach().numpy()
127
 
 
 
 
 
 
128
  self.logger.info("Converting to list")
129
  batch_emb = batch_emb.tolist()
130