chore: Update device for HuggingFaceBgeEmbeddings to dynamic device selection
app.py
CHANGED
@@ -1,10 +1,11 @@
 import os
+import torch
 
 import gradio as gr
 from dotenv import load_dotenv
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.embeddings import CacheBackedEmbeddings
-from
+from langchain_community.retrievers import BM25Retriever, EnsembleRetriever
 from langchain.storage import LocalFileStore
 from langchain_anthropic import ChatAnthropic
 from langchain_community.chat_models import ChatOllama
@@ -100,11 +101,25 @@ print(f"Number of split .ipynb files: {len(ipynb_docs)}")
 combined_documents = py_docs + mdx_docs + ipynb_docs
 print(f"Total number of documents: {len(combined_documents)}")
 
+
+# Define the device setting function
+def get_device():
+    if torch.cuda.is_available():
+        return "cuda:0"
+    elif torch.backends.mps.is_available():
+        return "mps"
+    else:
+        return "cpu"
+
+
+# Use the function to set the device in model_kwargs
+device = get_device()
+
 # Initialize embeddings and cache
 store = LocalFileStore("~/.cache/embedding")
 embeddings = HuggingFaceBgeEmbeddings(
     model_name="BAAI/bge-m3",
-    model_kwargs={"device":
+    model_kwargs={"device": device},
     encode_kwargs={"normalize_embeddings": True},
 )
 cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
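
For reference, a minimal self-contained sketch of how the change fits together once applied. The HuggingFaceBgeEmbeddings import path and the cache namespace below are assumptions for illustration; the corresponding lines of app.py fall outside the hunks shown above.

# Sketch only (not part of the commit): dynamic device selection feeding the
# cache-backed BGE embeddings. Import path and namespace are assumptions.
import torch
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.embeddings import HuggingFaceBgeEmbeddings


def get_device() -> str:
    # Prefer CUDA, then Apple Silicon (MPS), then fall back to CPU.
    if torch.cuda.is_available():
        return "cuda:0"
    elif torch.backends.mps.is_available():
        return "mps"
    return "cpu"


embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-m3",
    model_kwargs={"device": get_device()},
    encode_kwargs={"normalize_embeddings": True},
)
store = LocalFileStore("~/.cache/embedding")
cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
    embeddings, store, namespace=embeddings.model_name
)

# The first call computes and caches the vectors; repeated calls hit the local cache.
vectors = cached_embeddings.embed_documents(["hello world"])
print(get_device(), len(vectors[0]))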