gabrielaltay committed on
Commit
794de9b
1 Parent(s): f693380

take care of bfloat16 conversions

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -74,7 +74,10 @@ def get_device():
74
 
75
  def get_dtype(device: torch.device):
76
  if device == torch.device("cuda"):
77
- dtype = torch.bfloat16
 
 
 
78
  elif device == torch.device("mps"):
79
  dtype = torch.float32
80
  else:
@@ -114,7 +117,7 @@ def embed_page_images(model, processor, page_images, batch_size=1):
114
  embeddings = model(**batch)
115
  page_embeddings.extend(list(torch.unbind(embeddings.to("cpu"))))
116
  pbar.progress((ibatch + 1) / len(page_images), text="embedding pages")
117
- return np.array(page_embeddings)
118
 
119
 
120
  def embed_query_texts(model, processor, query_texts, batch_size=1):
@@ -132,7 +135,7 @@ def embed_query_texts(model, processor, query_texts, batch_size=1):
132
  batch = {k: v.to(model.device) for k, v in batch.items()}
133
  embeddings = model(**batch)
134
  query_embeddings.extend(list(torch.unbind(embeddings.to("cpu"))))
135
- return np.array(query_embeddings)[0]
136
 
137
 
138
  def get_pdf_page_images_from_bytes(
 
74
 
75
  def get_dtype(device: torch.device):
76
  if device == torch.device("cuda"):
77
+ if torch.cuda.is_bf16_supported():
78
+ dtype = torch.bfloat16
79
+ else:
80
+ dtype = torch.float16
81
  elif device == torch.device("mps"):
82
  dtype = torch.float32
83
  else:
 
117
  embeddings = model(**batch)
118
  page_embeddings.extend(list(torch.unbind(embeddings.to("cpu"))))
119
  pbar.progress((ibatch + 1) / len(page_images), text="embedding pages")
120
+ return np.array([el.to(torch.float32) for el in page_embeddings])
121
 
122
 
123
  def embed_query_texts(model, processor, query_texts, batch_size=1):
 
135
  batch = {k: v.to(model.device) for k, v in batch.items()}
136
  embeddings = model(**batch)
137
  query_embeddings.extend(list(torch.unbind(embeddings.to("cpu"))))
138
+ return np.array([el.to(torch.float32) for el in query_embeddings])[0]
139
 
140
 
141
  def get_pdf_page_images_from_bytes(