Spaces:
Runtime error
Runtime error
Adrian Cowham
commited on
Commit
·
c99fd41
1
Parent(s):
1e894a3
device conditionals
Browse files- src/core/embedding.py +7 -0
src/core/embedding.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from typing import List, Type
|
2 |
|
|
|
3 |
from langchain.docstore.document import Document
|
4 |
from langchain.embeddings import HuggingFaceEmbeddings
|
5 |
from langchain.embeddings.base import Embeddings
|
@@ -49,7 +50,13 @@ class FolderIndex:
|
|
49 |
|
50 |
def embed_files(files: List[File]) -> FolderIndex:
|
51 |
model_name = "adriancowham/letstalk-mythomax-embed-gte-small"
|
|
|
52 |
model_kwargs = {'device': 'cpu'}
|
|
|
|
|
|
|
|
|
|
|
53 |
encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
|
54 |
print("Loading model...")
|
55 |
try:
|
|
|
1 |
from typing import List, Type
|
2 |
|
3 |
+
import torch
|
4 |
from langchain.docstore.document import Document
|
5 |
from langchain.embeddings import HuggingFaceEmbeddings
|
6 |
from langchain.embeddings.base import Embeddings
|
|
|
50 |
|
51 |
def embed_files(files: List[File]) -> FolderIndex:
|
52 |
model_name = "adriancowham/letstalk-mythomax-embed-gte-small"
|
53 |
+
|
54 |
model_kwargs = {'device': 'cpu'}
|
55 |
+
if torch.cuda.is_available():
|
56 |
+
model_kwargs['device'] = 'cuda'
|
57 |
+
if torch.backends.mps.is_available():
|
58 |
+
model_kwargs['device'] = 'mps'
|
59 |
+
|
60 |
encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
|
61 |
print("Loading model...")
|
62 |
try:
|