heaversm committed
Commit 6c5b95d · 1 Parent(s): d2115ee

build a github similarity score retriever - no streamlit integration yet

Files changed (7)
  1. .gitignore +3 -0
  2. README.md +6 -0
  3. app.py +24 -0
  4. github.py +67 -0
  5. requirements.txt +3 -1
  6. search-pickle.py +99 -0
  7. unpickle.py +19 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ .env
+ .venv/
+ lib/
README.md CHANGED
@@ -11,3 +11,9 @@ license: mit
  ---
 
 
+ ## Local Dev
+
+ `python -m venv .venv`
+ `source .venv/bin/activate`
+ `pip install -r requirements.txt`
+ `streamlit run app.py`
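
`github.py` in this commit loads `GITHUB_ACCESS_TOKEN` via `python-dotenv`, and the new `.gitignore` keeps `.env` out of the repo, so local dev also needs a `.env` file next to the code. A minimal sketch (the token value is a placeholder):

```
# .env -- not committed; listed in .gitignore
GITHUB_ACCESS_TOKEN=<your-github-personal-access-token>
```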
app.py CHANGED
@@ -5,6 +5,7 @@ import pickle
  import torch
  import io
  from langchain.vectorstores import FAISS
+ import json
 
  class CPU_Unpickler(pickle.Unpickler):
      def find_class(self, module, name):
@@ -24,8 +25,29 @@ def get_hugging_face_model():
  def get_db():
      with open("codesearchdb.pickle", "rb") as f:
          db = CPU_Unpickler(f).load()
+         print("Loaded db")
+     # save_as_json(db, "codesearchdb.json")  # Save as JSON
      return db
 
+ def save_as_json(data, filename):
+     # Convert the data to a JSON-serializable format
+     serializable_data = data_to_serializable(data)
+     with open(filename, "w") as json_file:
+         json.dump(serializable_data, json_file)
+
+ def data_to_serializable(data):
+     if isinstance(data, dict):
+         return {k: data_to_serializable(v) for k, v in data.items() if not callable(v) and not isinstance(v, type)}
+     elif isinstance(data, list):
+         return [data_to_serializable(item) for item in data]
+     elif isinstance(data, (str, int, float, bool)) or data is None:
+         return data
+     elif hasattr(data, '__dict__'):
+         return data_to_serializable(data.__dict__)
+     elif hasattr(data, '__slots__'):
+         return {slot: data_to_serializable(getattr(data, slot)) for slot in data.__slots__}
+     else:
+         return str(data)  # Convert any other types to string
 
  def get_similar_links(query, db, embeddings):
      embedding_vector = embeddings.embed_query(query)
@@ -45,6 +67,7 @@ def get_similar_links(query, db, embeddings):
 
  embedding_vector = get_hugging_face_model()
  db = FAISS.load_local("code_sim_index", embedding_vector, allow_dangerous_deserialization=True)
+ save_as_json(db, "code_sim_index.json")  # Save as JSON
 
  st.title("Find Similar Code")
  text_input = st.text_area("Enter a Code Example", value=
@@ -73,3 +96,4 @@ if button:
  else:
      st.info("Please Input Valid Text")
 
+ # get_db()
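
The new `data_to_serializable` helper recursively reduces dicts, lists, and object `__dict__`/`__slots__` attributes to JSON primitives, filtering out callables and classes and stringifying everything else; that is what lets `save_as_json` dump an arbitrary object graph such as the FAISS store. A minimal sketch of the behavior, assuming the helper is in scope (the `Record` class is hypothetical, for illustration only):

```python
class Record:  # hypothetical stand-in for a vectorstore-ish object
    def __init__(self):
        self.path = "app.py"   # str: passes through unchanged
        self.score = 0.12      # float: passes through unchanged
        self.loader = open     # callable: filtered out of the dict branch

# Record has __dict__, so it is reduced via data_to_serializable(obj.__dict__)
print(data_to_serializable(Record()))
# -> {'path': 'app.py', 'score': 0.12}
```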
github.py ADDED
@@ -0,0 +1,67 @@
+ import os
+ from dotenv import load_dotenv
+ from langchain.document_loaders import GithubFileLoader
+ # from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import FAISS
+ from langchain_text_splitters import CharacterTextSplitter
+
+ load_dotenv()
+
+ # get the GITHUB_ACCESS_TOKEN from the .env file
+ GITHUB_ACCESS_TOKEN = os.getenv("GITHUB_ACCESS_TOKEN")
+ USER = "heaversm"
+ REPO = "gdrive-docker"
+ GITHUB_BASE_URL = "https://github.com/"
+
+
+ def get_similar_files(query, db, embeddings):
+     # embedding_vector = embeddings.embed_query(query)
+     # docs_and_scores = db.similarity_search_by_vector(embedding_vector, k=10)
+     docs_and_scores = db.similarity_search_with_score(query)
+     return docs_and_scores
+
+ def get_hugging_face_model():
+     model_name = "mchochlov/codebert-base-cd-ft"
+     hf = HuggingFaceEmbeddings(model_name=model_name)
+     return hf
+
+ loader = GithubFileLoader(
+     # repo is USER/REPO
+     repo=f"{USER}/{REPO}",
+     access_token=GITHUB_ACCESS_TOKEN,
+     github_api_url="https://api.github.com",
+     file_filter=lambda file_path: file_path.endswith(
+         (".py", ".ts")
+     ),  # load all Python and TypeScript files
+ )
+ documents = loader.load()
+ text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+ docs = text_splitter.split_documents(documents)
+ embedding_vector = get_hugging_face_model()
+ db = FAISS.from_documents(docs, embedding_vector)
+ model_name = "mchochlov/codebert-base-cd-ft"
+
+ query = """
+ def create_app():
+     app = connexion.FlaskApp(__name__, specification_dir="../.openapi")
+     app.add_api(
+         API_VERSION, resolver=connexion.resolver.RelativeResolver("provider.app")
+     )
+ """
+ results_with_scores = get_similar_files(query, db, embedding_vector)
+ print("retrieved!!!")
+ print(f"Number of results: {len(results_with_scores)}")
+
+ # score is a distance score: the lower, the better
+ for doc, score in results_with_scores:
+     print(f"Metadata: {doc.metadata}, Score: {score}")
+
+ top_file_path = results_with_scores[0][0].metadata['path']
+ top_file_content = results_with_scores[0][0].page_content
+ top_file_score = results_with_scores[0][1]
+ top_file_link = f"{GITHUB_BASE_URL}{USER}/{REPO}/blob/main/{top_file_path}"
+
+ print(f"Top file link: {top_file_link}")
+
+
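
For reference, FAISS's `similarity_search_with_score` returns a list of `(Document, score)` tuples, which is why the script indexes `results_with_scores[0][0]` for the document and `[0][1]` for the score. If the tail of the script grows, the link-building could be wrapped in a small helper; a sketch (the name `top_match_link` is mine, not from the commit):

```python
def top_match_link(results_with_scores, user=USER, repo=REPO):
    # results_with_scores: [(Document, score), ...] from
    # db.similarity_search_with_score(query); lower score = closer match
    doc, score = results_with_scores[0]
    path = doc.metadata["path"]
    return f"{GITHUB_BASE_URL}{user}/{repo}/blob/main/{path}", score

# link, score = top_match_link(results_with_scores)
```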
requirements.txt CHANGED
@@ -3,4 +3,6 @@ sentence-transformers
  bs4
  faiss-cpu
  altair==4.0
- langchain-community
+ langchain-community
+ streamlit
+ python-dotenv
search-pickle.py ADDED
@@ -0,0 +1,99 @@
+ import streamlit as st
+ from bs4 import BeautifulSoup
+ from langchain.embeddings import HuggingFaceEmbeddings
+ import pickle
+ import torch
+ import io
+ from langchain.vectorstores import FAISS
+ import json
+
+ class CPU_Unpickler(pickle.Unpickler):
+     def find_class(self, module, name):
+         if module == 'torch.storage' and name == '_load_from_bytes':
+             return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
+         else: return super().find_class(module, name)
+
+
+ @st.cache_resource
+ def get_hugging_face_model():
+     model_name = "mchochlov/codebert-base-cd-ft"
+     hf = HuggingFaceEmbeddings(model_name=model_name)
+     return hf
+
+
+ @st.cache_resource
+ def get_db():
+     with open("codesearchdb.pickle", "rb") as f:
+         db = CPU_Unpickler(f).load()
+         print("Loaded db")
+     # save_as_json(db, "codesearchdb.json")  # Save as JSON
+     return db
+
+ def save_as_json(data, filename):
+     # Convert the data to a JSON-serializable format
+     serializable_data = data_to_serializable(data)
+     with open(filename, "w") as json_file:
+         json.dump(serializable_data, json_file)
+
+ def data_to_serializable(data):
+     if isinstance(data, dict):
+         return {k: data_to_serializable(v) for k, v in data.items() if not callable(v) and not isinstance(v, type)}
+     elif isinstance(data, list):
+         return [data_to_serializable(item) for item in data]
+     elif isinstance(data, (str, int, float, bool)) or data is None:
+         return data
+     elif hasattr(data, '__dict__'):
+         return data_to_serializable(data.__dict__)
+     elif hasattr(data, '__slots__'):
+         return {slot: data_to_serializable(getattr(data, slot)) for slot in data.__slots__}
+     else:
+         return str(data)  # Convert any other types to string
+
+ def get_similar_links(query, db, embeddings):
+     embedding_vector = embeddings.embed_query(query)
+     docs_and_scores = db.similarity_search_by_vector(embedding_vector, k=10)
+     hrefs = []
+     for docs in docs_and_scores:
+         html_doc = docs.page_content
+         soup = BeautifulSoup(html_doc, 'html.parser')
+         href = [a['href'] for a in soup.find_all('a', href=True)]
+         hrefs.append(href)
+     links = []
+     for href_list in hrefs:
+         for link in href_list:
+             links.append(link)
+     return links
+
+
+ embedding_vector = get_hugging_face_model()
+ db = FAISS.load_local("code_sim_index", embedding_vector, allow_dangerous_deserialization=True)
+ save_as_json(db, "code_sim_index.json")  # Save as JSON
+
+ st.title("Find Similar Code")
+ text_input = st.text_area("Enter a Code Example", value=
+ """
+ class Solution:
+     def subsets(self, nums: List[int]) -> List[List[int]]:
+         outputs = []
+         def backtrack(k, index, subSet):
+             if index == k:
+                 outputs.append(subSet[:])
+                 return
+             for i in range(index, len(nums)):
+                 backtrack(k, i + 1, subSet + [nums[i]])
+         for j in range(len(nums) + 1):
+             backtrack(j, 0, [])
+         return outputs
+ """, height=330
+ )
+ button = st.button("Find Similar Questions")
+ if button:
+     query = text_input
+     answer = get_similar_links(query, db, embedding_vector)
+     for link in set(answer):
+         st.write(link)
+
+ else:
+     st.info("Please Input Valid Text")
+
+ # get_db()
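
`CPU_Unpickler` here is the usual workaround for pickles containing GPU tensors: it intercepts `torch.storage._load_from_bytes` and re-loads through `torch.load(..., map_location='cpu')`, so an index built on a GPU machine can be opened on a CPU-only host. Usage, with the class defined as above (the filename is hypothetical):

```python
# Any pickle whose torch storages were saved on a GPU:
with open("some_gpu_built.pickle", "rb") as f:  # hypothetical file
    obj = CPU_Unpickler(f).load()  # tensors land on the CPU
```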
unpickle.py ADDED
@@ -0,0 +1,19 @@
+ import pickle
+
+ # Define the path to the pickle file
+ pickle_file_path = 'codesearchdb.pickle'
+
+ # Load the pickle file
+ with open(pickle_file_path, 'rb') as file:
+     data = pickle.load(file)
+
+
+
+ # Save the contents to a new file (for example, a JSON file)
+ import json
+
+ json_file_path = 'codesearchdb.json'
+ with open(json_file_path, 'w') as json_file:
+     json.dump(data, json_file, indent=4)
+
+ print(f"Contents have been saved to {json_file_path}")
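
One caveat on this script: `json.dump` raises `TypeError` the moment it meets anything json cannot encode (tensors, or the unpickled FAISS object itself), which is presumably what motivated `data_to_serializable` in app.py. If a lossy dump is acceptable, json's built-in `default` hook is a blunter alternative:

```python
import json
import pickle

with open('codesearchdb.pickle', 'rb') as file:
    data = pickle.load(file)

# default=str stringifies anything json cannot encode natively,
# trading fidelity for a dump that always completes.
with open('codesearchdb.json', 'w') as json_file:
    json.dump(data, json_file, indent=4, default=str)
```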