ariG23498 (HF staff) committed
Commit 4d32fbd
Parent: 260ac27

adding lance

Files changed (2):
  1. app.py +11 -10
  2. requirements.txt +2 -1
app.py CHANGED
@@ -5,28 +5,29 @@ from datasets import load_dataset
 from sentence_transformers import SentenceTransformer
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import os
+import lancedb
 
 os.environ["HF_TOKEN"] = os.getenv("auth")
-dataset = load_dataset("ariG23498/pis-blogs-chunked")
+
+db = lancedb.connect("embedding_dataset")
+tbl = db.open_table("my_table")
+
 embedding_model = SentenceTransformer(model_name_or_path="all-mpnet-base-v2", device="cuda")
 tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
 model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16, device_map="auto")
 
 @spaces.GPU(duration=300)
 def process_query(query):
-    text_embeddings = embedding_model.encode(dataset["train"]["text"])
     query_embedding = embedding_model.encode(query)
+    search_hits = tbl.search(query_embedding).metric("cosine").limit(5).to_list()
 
-    similarity_scores = embedding_model.similarity(query_embedding, text_embeddings)
-    top_indices = (-similarity_scores).argsort()[0][:5]
-
-    context = dataset["train"]["text"][top_indices[0]]
-    url = dataset["train"]["url"][top_indices[0]]
+    context = search_hits[0]["text"]
+    url = search_hits[0]["url"]
 
     input_text = (
-        f"Based on the context provided, '{context}', how would"
-        f"you address the user's query regarding '{query}'? Please"
-        " provide a detailed and contextually relevant response."
+        f"You are being provided a query: {query} "
+        f"You are being provided context to the query: {context} "
+        "Please provide a detailed and contextually relevant response."
     )
 
     input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
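
The new code assumes the LanceDB table at embedding_dataset/my_table already exists; the commit only opens it. A minimal indexing sketch of how such a table could be built, assuming it stores the same text and url columns as the ariG23498/pis-blogs-chunked dataset the old code loaded, alongside an all-mpnet-base-v2 vector column (the connect path and table name come from the diff; the rest is an assumption):

import lancedb
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# Hypothetical one-off script: embed every chunk once and persist it to LanceDB.
dataset = load_dataset("ariG23498/pis-blogs-chunked")
embedding_model = SentenceTransformer("all-mpnet-base-v2")

texts = dataset["train"]["text"]
urls = dataset["train"]["url"]
vectors = embedding_model.encode(texts)  # one 768-dim vector per chunk

db = lancedb.connect("embedding_dataset")
tbl = db.create_table(
    "my_table",
    data=[
        {"vector": vector.tolist(), "text": text, "url": url}
        for vector, text, url in zip(vectors, texts, urls)
    ],
)

With the table prebuilt, tbl.search(query_embedding).metric("cosine").limit(5).to_list() returns plain dicts carrying the stored columns plus a _distance score, which is why process_query can read search_hits[0]["text"] directly instead of re-encoding the whole corpus on every call.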
requirements.txt CHANGED
@@ -2,4 +2,5 @@ accelerate
 transformers
 gradio
 sentence-transformers
-datasets
+datasets
+lancedb