juanpablomesa
commited on
Commit
·
dae7ab5
1
Parent(s):
073334d
Modified the squeeze() application to embeddings
Browse files- handler.py +6 -5
handler.py
CHANGED
@@ -115,18 +115,19 @@ class EndpointHandler:
|
|
115 |
# tensor = torch.stack(frame_embedding)
|
116 |
|
117 |
# detach text emb from graph, move to CPU, and convert to numpy array
|
118 |
-
self.logger.info("Squeezing tensor")
|
119 |
-
batch_emb = frame_embedding.squeeze(0)
|
120 |
|
121 |
# Check the shape of the tensor
|
122 |
-
self.logger.info(f"Shape of the batch_emb tensor: {
|
123 |
|
124 |
# Normalize the embeddings if it's a 2D tensor
|
125 |
-
if
|
126 |
self.logger.info("Normalizing embeddings")
|
127 |
-
batch_emb = torch.nn.functional.normalize(
|
128 |
else:
|
129 |
self.logger.info("Skipping normalization due to tensor shape")
|
|
|
130 |
|
131 |
self.logger.info("Converting into numpy array")
|
132 |
batch_emb = batch_emb.cpu().detach().numpy()
|
|
|
115 |
# tensor = torch.stack(frame_embedding)
|
116 |
|
117 |
# detach text emb from graph, move to CPU, and convert to numpy array
|
118 |
+
# self.logger.info("Squeezing tensor")
|
119 |
+
# batch_emb = frame_embedding.squeeze(0)
|
120 |
|
121 |
# Check the shape of the tensor
|
122 |
+
self.logger.info(f"Shape of the batch_emb tensor: {frame_embedding.shape}")
|
123 |
|
124 |
# Normalize the embeddings if it's a 2D tensor
|
125 |
+
if frame_embedding.dim() == 2:
|
126 |
self.logger.info("Normalizing embeddings")
|
127 |
+
batch_emb = torch.nn.functional.normalize(frame_embedding, p=2, dim=1)
|
128 |
else:
|
129 |
self.logger.info("Skipping normalization due to tensor shape")
|
130 |
+
batch_emb = frame_embedding.squeeze(0)
|
131 |
|
132 |
self.logger.info("Converting into numpy array")
|
133 |
batch_emb = batch_emb.cpu().detach().numpy()
|