# Codebert-Repo-Analyzer / search-pickle.py
import io
import json
import pickle

import streamlit as st
import torch
from bs4 import BeautifulSoup
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that remaps torch tensors saved on GPU onto the CPU."""

    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        else:
            return super().find_class(module, name)
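# Example (sketch): loading a pickle that was written on a CUDA machine.
# "gpu_db.pickle" is a hypothetical filename used only for illustration.
# with open("gpu_db.pickle", "rb") as f:
#     db_on_cpu = CPU_Unpickler(f).load()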
@st.cache_resource
def get_hugging_face_model():
    # Cache the embedding model across Streamlit reruns.
    model_name = "mchochlov/codebert-base-cd-ft"
    hf = HuggingFaceEmbeddings(model_name=model_name)
    return hf
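# Example (sketch): embed_query returns a list of floats; for this
# CodeBERT-base fine-tune the vector length should be 768 (assumption).
# vec = get_hugging_face_model().embed_query("def add(a, b): return a + b")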
@st.cache_resource
def get_db():
    # Load a pickled vector store, forcing any GPU tensors onto the CPU.
    with open("codesearchdb.pickle", "rb") as f:
        db = CPU_Unpickler(f).load()
    print("Loaded db")
    # save_as_json(db, "codesearchdb.json")  # Save as JSON
    return db
def save_as_json(data, filename):
    # Convert the data to a JSON-serializable format before writing.
    serializable_data = data_to_serializable(data)
    with open(filename, "w") as json_file:
        json.dump(serializable_data, json_file)
def data_to_serializable(data):
    # Recursively reduce arbitrary objects to JSON-friendly primitives.
    if isinstance(data, dict):
        return {k: data_to_serializable(v) for k, v in data.items()
                if not callable(v) and not isinstance(v, type)}
    elif isinstance(data, list):
        return [data_to_serializable(item) for item in data]
    elif isinstance(data, (str, int, float, bool)) or data is None:
        return data
    elif hasattr(data, '__dict__'):
        return data_to_serializable(data.__dict__)
    elif hasattr(data, '__slots__'):
        return {slot: data_to_serializable(getattr(data, slot)) for slot in data.__slots__}
    else:
        return str(data)  # Convert any other types to string
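# Example (sketch): a plain object collapses to its __dict__.
# class Point:
#     def __init__(self):
#         self.x, self.y = 1, 2.5
# data_to_serializable(Point())  # -> {"x": 1, "y": 2.5}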
def get_similar_links(query, db, embeddings):
    # Embed the query and fetch the 10 nearest documents from the index.
    embedding_vector = embeddings.embed_query(query)
    docs = db.similarity_search_by_vector(embedding_vector, k=10)
    links = []
    for doc in docs:
        # Each stored document is raw HTML; collect every anchor's href.
        soup = BeautifulSoup(doc.page_content, 'html.parser')
        links.extend(a['href'] for a in soup.find_all('a', href=True))
    return links
# Load the embedding model and the locally saved FAISS index.
embeddings = get_hugging_face_model()
db = FAISS.load_local("code_sim_index", embeddings, allow_dangerous_deserialization=True)
save_as_json(db, "code_sim_index.json")  # Save a JSON snapshot of the index
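# Example (sketch): querying the index directly, outside the Streamlit UI.
# links = get_similar_links("def two_sum(nums, target): ...", db, embeddings)
# for link in set(links):
#     print(link)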
st.title("Find Similar Code")
text_input = st.text_area(
    "Enter a Code Example",
    value="""
class Solution:
    def subsets(self, nums: List[int]) -> List[List[int]]:
        outputs = []

        def backtrack(k, index, subSet):
            if len(subSet) == k:
                outputs.append(subSet[:])
                return
            for i in range(index, len(nums)):
                backtrack(k, i + 1, subSet + [nums[i]])

        for j in range(len(nums) + 1):
            backtrack(j, 0, [])
        return outputs
""",
    height=330,
)
button = st.button("Find Similar Questions")
if button:
    query = text_input
    answer = get_similar_links(query, db, embeddings)
    for link in set(answer):
        st.write(link)
else:
    st.info("Enter a code snippet above and click the button.")
# get_db()