Adrian Cowham commited on
Commit
c99fd41
·
1 Parent(s): 1e894a3

device conditionals

Browse files
Files changed (1) hide show
  1. src/core/embedding.py +7 -0
src/core/embedding.py CHANGED
@@ -1,5 +1,6 @@
1
  from typing import List, Type
2
 
 
3
  from langchain.docstore.document import Document
4
  from langchain.embeddings import HuggingFaceEmbeddings
5
  from langchain.embeddings.base import Embeddings
@@ -49,7 +50,13 @@ class FolderIndex:
49
 
50
  def embed_files(files: List[File]) -> FolderIndex:
51
  model_name = "adriancowham/letstalk-mythomax-embed-gte-small"
 
52
  model_kwargs = {'device': 'cpu'}
 
 
 
 
 
53
  encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
54
  print("Loading model...")
55
  try:
 
1
  from typing import List, Type
2
 
3
+ import torch
4
  from langchain.docstore.document import Document
5
  from langchain.embeddings import HuggingFaceEmbeddings
6
  from langchain.embeddings.base import Embeddings
 
50
 
51
  def embed_files(files: List[File]) -> FolderIndex:
52
  model_name = "adriancowham/letstalk-mythomax-embed-gte-small"
53
+
54
  model_kwargs = {'device': 'cpu'}
55
+ if torch.cuda.is_available():
56
+ model_kwargs['device'] = 'cuda'
57
+ if torch.backends.mps.is_available():
58
+ model_kwargs['device'] = 'mps'
59
+
60
  encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
61
  print("Loading model...")
62
  try: