heaversm commited on
Commit
449cbf5
·
1 Parent(s): a107c82

initial commit - command line only.

Browse files
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  .env
2
- .venv
 
 
1
  .env
2
+ .venv
3
+ **/__pycache__/*
data/db/chroma.sqlite3 ADDED
Binary file (156 kB). View file
 
lib/chain.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from operator import itemgetter
3
+
4
+ from langchain_chroma import Chroma
5
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
6
+ from langchain_core.runnables import RunnablePassthrough, RunnableParallel
7
+ from langchain_core.output_parsers import JsonOutputParser
8
+ from langchain.prompts import PromptTemplate
9
+
10
+ from lib.models import MODELS_MAP
11
+ from lib.utils import format_docs, retrieve_answer, load_embeddings
12
+ from lib.entities import LLMEvalResult
13
+
14
+ def create_retriever(llm_name, db_path, docs, collection_name="local-rag"):
15
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=60)
16
+
17
+ splits = text_splitter.split_documents(docs)
18
+
19
+ embeddings = load_embeddings(llm_name)
20
+
21
+ if not os.path.exists(db_path):
22
+ vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings, persist_directory=db_path, collection_name=collection_name)
23
+ else:
24
+ vectorstore = Chroma(persist_directory=db_path, embedding_function=embeddings, collection_name=collection_name)
25
+
26
+ retriever = vectorstore.as_retriever()
27
+ return retriever
28
+
29
+ def create_qa_chain(llm, retriever, prompts_text):
30
+ initial_prompt_text = prompts_text["initial_prompt"]
31
+ qa_eval_prompt_text = prompts_text["evaluation_prompt"]
32
+
33
+ initial_prompt = PromptTemplate(
34
+ template=initial_prompt_text,
35
+ input_variables=["question", "context"]
36
+ )
37
+
38
+ json_parser = JsonOutputParser(pydantic_object=LLMEvalResult)
39
+ qa_eval_prompt = PromptTemplate(
40
+ template=qa_eval_prompt_text,
41
+ input_variables=["question","answer"],
42
+ partial_variables={"format_instructions": json_parser.get_format_instructions()},
43
+ )
44
+
45
+ qa_eval_prompt_with_context = PromptTemplate(
46
+ template=qa_eval_prompt_text,
47
+ input_variables=["question","answer","context"],
48
+ partial_variables={"format_instructions": json_parser.get_format_instructions()},
49
+ )
50
+
51
+ chain = (
52
+ RunnableParallel(context = retriever | format_docs, question = RunnablePassthrough()) |
53
+ RunnableParallel(answer = initial_prompt | llm | retrieve_answer, question = itemgetter("question"), context = itemgetter("context") ) |
54
+ RunnableParallel(input = qa_eval_prompt, context = itemgetter("context"), answer = itemgetter("answer")) |
55
+ RunnableParallel(evaluation = itemgetter("input") | llm , context = itemgetter("context"), answer = itemgetter("answer") ) |
56
+ RunnableParallel(output = itemgetter("answer"), evaluation = itemgetter("evaluation") | json_parser, context = itemgetter("context"))
57
+ )
58
+ return chain
lib/entities.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from langchain_core.pydantic_v1 import BaseModel, Field
3
+
4
+ class AccuracyEnum(str, Enum):
5
+ accurate = "accurate"
6
+ inaccurate = "inaccurate"
7
+
8
+ class LLMEvalResult(BaseModel):
9
+ accuracy: AccuracyEnum = Field(description="Label indicating if the answer is accurate or inaccurate.")
10
+ feedback: str = Field(description="Explanation of why the specific label was assigned. Must be concise and not more than 2 sentences.")
lib/loader.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain_community.document_loaders.generic import GenericLoader
3
+ from langchain_community.document_loaders.parsers import LanguageParser
4
+ from langchain_text_splitters import Language
5
+
6
+ def load_files(repository_path):
7
+ loader = GenericLoader.from_filesystem(
8
+ repository_path,
9
+ glob="**/*",
10
+ suffixes=[".py"],
11
+ parser=LanguageParser(
12
+ language=Language.PYTHON
13
+ )
14
+ )
15
+
16
+ docs = loader.load()
17
+
18
+ return docs
lib/models.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from langchain_openai import OpenAI
4
+ from langchain_groq import ChatGroq
5
+ from langchain_openai import OpenAIEmbeddings
6
+ from langchain_huggingface import HuggingFaceEmbeddings
7
+
8
+ load_dotenv()
9
+
10
+ MODELS_MAP = {
11
+ "OpenAI gpt-4o": {
12
+ "class": OpenAI,
13
+ "params": {
14
+ "temperature": 0,
15
+ "api_key": os.getenv("OPENAI_API_KEY")
16
+ },
17
+ "embedding_class": OpenAIEmbeddings,
18
+ "embedding_params": {
19
+ "api_key": os.getenv("OPENAI_API_KEY")
20
+ }
21
+ },
22
+ "Groq LLaMA3 70b": {
23
+ "class": ChatGroq,
24
+ "params": {
25
+ "model_name": "llama3-70b-8192",
26
+ "groq_api_key": os.getenv("GROQ_API_KEY")
27
+ },
28
+ "embedding_class": HuggingFaceEmbeddings,
29
+ "embedding_params": {
30
+ "model_name": "sentence-transformers/all-MiniLM-L6-v2"
31
+ }
32
+ },
33
+ "Groq Mixtral 8x7b": {
34
+ "class": ChatGroq,
35
+ "params": {
36
+ "model_name": "mixtral-8x7b-32768",
37
+ "groq_api_key": os.getenv("GROQ_API_KEY")
38
+ },
39
+ "embedding_class": HuggingFaceEmbeddings,
40
+ "embedding_params": {
41
+ "model_name": "sentence-transformers/all-MiniLM-L6-v2"
42
+ }
43
+ }
44
+ }
lib/repository.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import git
3
+
4
+ def download_github_repo(repo_url, repo_dir):
5
+ if os.path.exists(repo_dir):
6
+ print(f"Repository {repo_dir} already exists. Pulling latest changes.")
7
+ repo = git.Repo(repo_dir)
8
+ repo.remotes.origin.pull()
9
+ else:
10
+ print(f"Cloning repository from {repo_url} to {repo_dir}.")
11
+ git.Repo.clone_from(repo_url, repo_dir)
lib/utils.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib.models import MODELS_MAP
2
+
3
+ def read_prompt(file_name):
4
+ with open(file_name, 'r') as file:
5
+ return file.read()
6
+
7
+ def format_docs(docs):
8
+ return "\n\n".join(doc.page_content for doc in docs)
9
+
10
+ def retrieve_answer(output):
11
+ # print(f"Output: {output}")
12
+ # return output.content
13
+ return output
14
+
15
+ def load_LLM(llm_name):
16
+ model_config = MODELS_MAP[llm_name]
17
+ model_class = model_config["class"]
18
+ params = model_config["params"]
19
+ llm = model_class(**params)
20
+ return llm
21
+
22
+ def load_embeddings(llm_name):
23
+ model_config = MODELS_MAP[llm_name]
24
+ embedding_class = model_config["embedding_class"]
25
+ embedding_params = model_config["embedding_params"]
26
+ embeddings = embedding_class(**embedding_params)
27
+ return embeddings
28
+
29
+ def get_available_models():
30
+ return list(MODELS_MAP.keys())
31
+
32
+ def select_model():
33
+ models = get_available_models()
34
+ print("Available Models:")
35
+ for i, model in enumerate(models):
36
+ print(f"{i + 1}. {model}")
37
+
38
+ while True:
39
+ try:
40
+ choice = int(input("Select a model by number: ")) - 1
41
+ if 0 <= choice < len(models):
42
+ return models[choice]
43
+ else:
44
+ print("Invalid choice. Please select a number from the list.")
45
+ except ValueError:
46
+ print("Invalid input. Please enter a number.")
main.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from dotenv import load_dotenv
4
+
5
+ from langchain.globals import set_debug
6
+ from langchain_core.runnables import RunnablePassthrough
7
+ from langchain_core.output_parsers import StrOutputParser
8
+
9
+ from lib.repository import download_github_repo
10
+ from lib.loader import load_files
11
+ from lib.chain import create_retriever, create_qa_chain
12
+ from lib.utils import read_prompt, load_LLM, select_model
13
+ from lib.models import MODELS_MAP
14
+
15
+ # set_debug(True)
16
+
17
+ def main():
18
+ # Prompt user to select the model
19
+ model_name = select_model()
20
+ model_info = MODELS_MAP[model_name]
21
+
22
+ # Parse the command line arguments
23
+ parser = argparse.ArgumentParser(description="GitHub Repo QA CLI Application")
24
+ parser.add_argument("repo_url", type=str, help="URL of the GitHub repository")
25
+ args = parser.parse_args()
26
+
27
+ # Extract the repository name from the URL
28
+ repo_url = args.repo_url
29
+ repo_name = repo_url.split("/")[-1].replace(".git", "")
30
+
31
+ # Compute the path to the data folder relative to the script's directory
32
+ base_dir = os.path.dirname(os.path.abspath(__file__))
33
+ repo_dir = os.path.join(base_dir, "data", repo_name)
34
+ db_dir = os.path.join(base_dir, "data", "db")
35
+ prompt_templates_dir = os.path.join(base_dir, "prompt_templates")
36
+
37
+ # Download the GitHub repository
38
+ print(f"Downloading repository from {repo_url}...")
39
+ download_github_repo(repo_url, repo_dir)
40
+
41
+ # Load prompt templates
42
+ prompts_text = {
43
+ "initial_prompt": read_prompt(os.path.join(prompt_templates_dir, 'initial_prompt.txt')),
44
+ "evaluation_prompt": read_prompt(os.path.join(prompt_templates_dir, 'evaluation_prompt.txt')),
45
+ }
46
+
47
+ # Load documents from the repository
48
+ print(f"Loading documents from {repo_dir}...")
49
+ document_chunks = load_files(repository_path=repo_dir)
50
+ print(f"Created chunks length is: {len(document_chunks)}")
51
+
52
+ # Create model, retriever
53
+ print(f"Creating retrieval QA chain using {model_name}...")
54
+ llm = load_LLM(model_name)
55
+ retriever = create_retriever(model_name, db_dir, document_chunks)
56
+ qa_chain = create_qa_chain(llm, retriever, prompts_text)
57
+
58
+ print("You can start asking questions. Type 'exit' to quit.")
59
+ while True:
60
+ question = input("Question: ")
61
+ if question.lower() == "exit":
62
+ break
63
+ answer = qa_chain.invoke(question)
64
+ print(f"Answer: {answer['output']}")
65
+
66
+
67
+ if __name__ == "__main__":
68
+ main()
prompt_templates/README.md ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ `python -m venv .venv`
2
+ `source .venv/bin/activate`
3
+ `pip3 install -r requirements.txt`
4
+ `python3 main.py https://github.com/streamlit/streamlit`
prompt_templates/evaluation_prompt.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ You are a technical assessor reviewing a test. You are provided with a question along with an answer for the question written by a developer. Evaluate the question-answer pair and provide feedback.
2
+ {format_instructions}
3
+ Question: {question}
4
+ Answer: {answer}
prompt_templates/initial_prompt.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ You are an assistant for question-answering tasks in the software engineering field. Use the following pieces of retrieved context from the provided GitHub repository to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise. If applicable, include a brief code snippet to illustrate your answer.
2
+ Question: {question}
3
+ Repository Context: {context}
4
+ Answer:
requirements.txt CHANGED
@@ -15,4 +15,6 @@ langchain-text-splitters
15
  esprima
16
  tree_sitter
17
  tree_sitter_languages
18
- pysqlite3
 
 
 
15
  esprima
16
  tree_sitter
17
  tree_sitter_languages
18
+ pysqlite3-binary
19
+ git
20
+ gradio