Spaces:
Running
Running
gabrielaltay
committed on
Commit
•
794de9b
1
Parent(s):
f693380
take care of bfloat16 conversions
Browse files
app.py
CHANGED
@@ -74,7 +74,10 @@ def get_device():
|
|
74 |
|
75 |
def get_dtype(device: torch.device):
|
76 |
if device == torch.device("cuda"):
|
77 |
-
|
|
|
|
|
|
|
78 |
elif device == torch.device("mps"):
|
79 |
dtype = torch.float32
|
80 |
else:
|
@@ -114,7 +117,7 @@ def embed_page_images(model, processor, page_images, batch_size=1):
|
|
114 |
embeddings = model(**batch)
|
115 |
page_embeddings.extend(list(torch.unbind(embeddings.to("cpu"))))
|
116 |
pbar.progress((ibatch + 1) / len(page_images), text="embedding pages")
|
117 |
-
return np.array(page_embeddings)
|
118 |
|
119 |
|
120 |
def embed_query_texts(model, processor, query_texts, batch_size=1):
|
@@ -132,7 +135,7 @@ def embed_query_texts(model, processor, query_texts, batch_size=1):
|
|
132 |
batch = {k: v.to(model.device) for k, v in batch.items()}
|
133 |
embeddings = model(**batch)
|
134 |
query_embeddings.extend(list(torch.unbind(embeddings.to("cpu"))))
|
135 |
-
return np.array(query_embeddings)[0]
|
136 |
|
137 |
|
138 |
def get_pdf_page_images_from_bytes(
|
|
|
74 |
|
75 |
def get_dtype(device: torch.device):
|
76 |
if device == torch.device("cuda"):
|
77 |
+
if torch.cuda.is_bf16_supported():
|
78 |
+
dtype = torch.bfloat16
|
79 |
+
else:
|
80 |
+
dtype = torch.float16
|
81 |
elif device == torch.device("mps"):
|
82 |
dtype = torch.float32
|
83 |
else:
|
|
|
117 |
embeddings = model(**batch)
|
118 |
page_embeddings.extend(list(torch.unbind(embeddings.to("cpu"))))
|
119 |
pbar.progress((ibatch + 1) / len(page_images), text="embedding pages")
|
120 |
+
return np.array([el.to(torch.float32) for el in page_embeddings])
|
121 |
|
122 |
|
123 |
def embed_query_texts(model, processor, query_texts, batch_size=1):
|
|
|
135 |
batch = {k: v.to(model.device) for k, v in batch.items()}
|
136 |
embeddings = model(**batch)
|
137 |
query_embeddings.extend(list(torch.unbind(embeddings.to("cpu"))))
|
138 |
+
return np.array([el.to(torch.float32) for el in query_embeddings])[0]
|
139 |
|
140 |
|
141 |
def get_pdf_page_images_from_bytes(
|