heaversm commited on
Commit
ab2d07e
·
1 Parent(s): 449cbf5

add frontend in gradio

Browse files
Files changed (4) hide show
  1. lib/utils.py +10 -9
  2. main.py +37 -19
  3. prompt_templates/README.md +3 -1
  4. requirements.txt +2 -1
lib/utils.py CHANGED
@@ -35,12 +35,13 @@ def select_model():
35
  for i, model in enumerate(models):
36
  print(f"{i + 1}. {model}")
37
 
38
- while True:
39
- try:
40
- choice = int(input("Select a model by number: ")) - 1
41
- if 0 <= choice < len(models):
42
- return models[choice]
43
- else:
44
- print("Invalid choice. Please select a number from the list.")
45
- except ValueError:
46
- print("Invalid input. Please enter a number.")
 
 
35
  for i, model in enumerate(models):
36
  print(f"{i + 1}. {model}")
37
 
38
+ return models[0] #MH - temp, just use openAI
39
+ # while True:
40
+ # try:
41
+ # choice = int(input("Select a model by number: ")) - 1
42
+ # if 0 <= choice < len(models):
43
+ # return models[choice]
44
+ # else:
45
+ # print("Invalid choice. Please select a number from the list.")
46
+ # except ValueError:
47
+ # print("Invalid input. Please enter a number.")
main.py CHANGED
@@ -12,20 +12,44 @@ from lib.chain import create_retriever, create_qa_chain
12
  from lib.utils import read_prompt, load_LLM, select_model
13
  from lib.models import MODELS_MAP
14
 
 
 
 
 
 
 
 
 
15
  # set_debug(True)
16
 
17
- def main():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  # Prompt user to select the model
19
  model_name = select_model()
20
  model_info = MODELS_MAP[model_name]
21
-
22
- # Parse the command line arguments
23
- parser = argparse.ArgumentParser(description="GitHub Repo QA CLI Application")
24
- parser.add_argument("repo_url", type=str, help="URL of the GitHub repository")
25
- args = parser.parse_args()
26
-
27
- # Extract the repository name from the URL
28
- repo_url = args.repo_url
29
  repo_name = repo_url.split("/")[-1].replace(".git", "")
30
 
31
  # Compute the path to the data folder relative to the script's directory
@@ -53,16 +77,10 @@ def main():
53
  print(f"Creating retrieval QA chain using {model_name}...")
54
  llm = load_LLM(model_name)
55
  retriever = create_retriever(model_name, db_dir, document_chunks)
 
56
  qa_chain = create_qa_chain(llm, retriever, prompts_text)
57
-
58
- print("You can start asking questions. Type 'exit' to quit.")
59
- while True:
60
- question = input("Question: ")
61
- if question.lower() == "exit":
62
- break
63
- answer = qa_chain.invoke(question)
64
- print(f"Answer: {answer['output']}")
65
-
66
 
67
  if __name__ == "__main__":
68
- main()
 
12
  from lib.utils import read_prompt, load_LLM, select_model
13
  from lib.models import MODELS_MAP
14
 
15
+ import time
16
+ import gradio as gr
17
+
18
+ def slow_echo(message, history):
19
+ for i in range(len(message)):
20
+ time.sleep(0.05)
21
+ yield message[: i + 1]
22
+
23
  # set_debug(True)
24
 
25
+ def build():
26
+ with gr.Blocks() as demo:
27
+ repo_url = gr.Textbox(label="Repo URL", placeholder="Enter the repository URL here...")
28
+ submit_btn = gr.Button("Submit Repo URL")
29
+
30
+ user_input = gr.Textbox(label="User Input", placeholder="Enter your question here...")
31
+ chat_output = gr.Textbox(label="Chat Output", placeholder="The answer will appear here...")
32
+ # add a status textbox
33
+
34
+ def update_repo_url(new_url):
35
+ updated_url = main(new_url)
36
+ return updated_url
37
+
38
+ def generate_answer(user_input):
39
+ answer = qa_chain.invoke(user_input)
40
+ print(f"Answer: {answer}")
41
+ return answer['output']
42
+
43
+ submit_btn.click(update_repo_url, inputs=repo_url, outputs=repo_url)
44
+ user_input_submit_btn = gr.Button("Submit Question")
45
+ user_input_submit_btn.click(generate_answer, inputs=user_input, outputs=chat_output)
46
+
47
+ demo.launch()
48
+
49
+ def main(repo_url):
50
  # Prompt user to select the model
51
  model_name = select_model()
52
  model_info = MODELS_MAP[model_name]
 
 
 
 
 
 
 
 
53
  repo_name = repo_url.split("/")[-1].replace(".git", "")
54
 
55
  # Compute the path to the data folder relative to the script's directory
 
77
  print(f"Creating retrieval QA chain using {model_name}...")
78
  llm = load_LLM(model_name)
79
  retriever = create_retriever(model_name, db_dir, document_chunks)
80
+ global qa_chain
81
  qa_chain = create_qa_chain(llm, retriever, prompts_text)
82
+ print(f"Ready to chat!")
83
+ return repo_url
 
 
 
 
 
 
 
84
 
85
  if __name__ == "__main__":
86
+ build()
prompt_templates/README.md CHANGED
@@ -1,4 +1,6 @@
1
  `python -m venv .venv`
2
  `source .venv/bin/activate`
3
  `pip3 install -r requirements.txt`
4
- `python3 main.py https://github.com/streamlit/streamlit`
 
 
 
1
  `python -m venv .venv`
2
  `source .venv/bin/activate`
3
  `pip3 install -r requirements.txt`
4
+ `python3 main.py https://github.com/streamlit/streamlit`
5
+
6
+ where the URL is the repo you want to ask questions about
requirements.txt CHANGED
@@ -17,4 +17,5 @@ tree_sitter
17
  tree_sitter_languages
18
  pysqlite3-binary
19
  git
20
- gradio
 
 
17
  tree_sitter_languages
18
  pysqlite3-binary
19
  git
20
+ gradio
21
+ time