juanpablomesa committed
Commit dae7ab5 · 1 Parent(s): 073334d

Modified the squeeze() application to embeddings

Files changed (1)
  1. handler.py +6 -5
handler.py CHANGED
@@ -115,18 +115,19 @@ class EndpointHandler:
         # tensor = torch.stack(frame_embedding)

         # detach text emb from graph, move to CPU, and convert to numpy array
-        self.logger.info("Squeezing tensor")
-        batch_emb = frame_embedding.squeeze(0)
+        # self.logger.info("Squeezing tensor")
+        # batch_emb = frame_embedding.squeeze(0)

         # Check the shape of the tensor
-        self.logger.info(f"Shape of the batch_emb tensor: {batch_emb.shape}")
+        self.logger.info(f"Shape of the batch_emb tensor: {frame_embedding.shape}")

         # Normalize the embeddings if it's a 2D tensor
-        if batch_emb.dim() == 2:
+        if frame_embedding.dim() == 2:
             self.logger.info("Normalizing embeddings")
-            batch_emb = torch.nn.functional.normalize(batch_emb, p=2, dim=1)
+            batch_emb = torch.nn.functional.normalize(frame_embedding, p=2, dim=1)
         else:
             self.logger.info("Skipping normalization due to tensor shape")
+            batch_emb = frame_embedding.squeeze(0)

         self.logger.info("Converting into numpy array")
         batch_emb = batch_emb.cpu().detach().numpy()
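For readers skimming the diff, a minimal runnable sketch of the post-commit flow follows. The postprocess helper, the 512-dim embeddings, and the example shapes are illustrative assumptions, not part of handler.py; only the branch logic mirrors the new code above.

import torch

def postprocess(frame_embedding: torch.Tensor):
    # 2D (batch, dim) tensor: L2-normalize each row, as the new `if` branch does.
    if frame_embedding.dim() == 2:
        batch_emb = torch.nn.functional.normalize(frame_embedding, p=2, dim=1)
    else:
        # Otherwise drop the leading singleton axis, e.g. (1, batch, dim) -> (batch, dim);
        # after this commit, squeeze(0) runs only here instead of unconditionally up front.
        batch_emb = frame_embedding.squeeze(0)
    # Detach from the graph, move to CPU, and hand back a numpy array.
    return batch_emb.cpu().detach().numpy()

print(postprocess(torch.randn(4, 512)).shape)     # (4, 512), rows unit-norm
print(postprocess(torch.randn(1, 4, 512)).shape)  # (4, 512), squeezed, not normalized

The net effect of the commit: the shape check and normalization now key off the raw frame_embedding rather than a pre-squeezed copy, and squeeze(0) is applied only when the tensor is not already 2D.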