zamalali committed
Commit 94ed277 · 1 Parent(s): 27298b1

Refactor app.py and main.py for improved readability and functionality; add environment variable loading

Files changed (3)
  1. __pycache__/main.cpython-311.pyc +0 -0
  2. app.py +4 -8
  3. main.py +33 -20
__pycache__/main.cpython-311.pyc ADDED
Binary file (21.6 kB).
 
app.py CHANGED

@@ -1,12 +1,11 @@
 import spaces
-
-
 import gradio as gr
 import time
-from gradio.themes.utils import sizes
 import threading
 import logging
-from main import run_repository_ranking  # Your repository ranking function
+from gradio.themes.utils import sizes
+from main import run_repository_ranking  # Import the repository ranking function
+
 # ---------------------------
 # Global Logging Buffer Setup
 # ---------------------------
@@ -45,8 +44,7 @@ def parse_result_to_html(raw_result: str) -> str:
     Only the top 10 results are displayed.
     """
     entries = raw_result.strip().split("Final Rank:")
-    # Only use the first 10 entries (if available)
-    entries = entries[1:11]
+    entries = entries[1:11]  # Use only the first 10 entries
     if not entries:
         return "<p>No repositories found for your query.</p>"
     html = """
@@ -163,8 +161,6 @@ with gr.Blocks(
         elem_id="header"
     )
 
-
-    # Centered main container for inputs and outputs.
    with gr.Column(elem_id="main-container"):
        research_input = gr.Textbox(
            label="Research Query",
main.py CHANGED

@@ -6,7 +6,12 @@ import faiss
 import re
 import logging
 from pathlib import Path
+
+# For local development, load environment variables from a .env file.
+# In HuggingFace Spaces, secrets are automatically available as environment variables.
 from dotenv import load_dotenv
+load_dotenv()
+
 from sentence_transformers import SentenceTransformer, CrossEncoder
 from langchain_groq import ChatGroq
 from langchain_core.prompts import ChatPromptTemplate
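A note on the load_dotenv() call added above: python-dotenv defaults to override=False, so variables already present in the environment, such as secrets injected by the HuggingFace Spaces runtime, take precedence over a local .env file. A minimal sketch:

```python
# Sketch of python-dotenv's default precedence (override=False): an existing
# environment variable is kept even if a local .env file defines the same key.
import os
from dotenv import load_dotenv

os.environ["GITHUB_API_KEY"] = "injected-by-spaces"  # stand-in for a Spaces secret
load_dotenv()  # a .env entry for GITHUB_API_KEY would not replace the value above
print(os.getenv("GITHUB_API_KEY"))  # -> injected-by-spaces
```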
@@ -18,14 +23,21 @@ except ImportError:
     BM25Okapi = None
 
 # ---------------------------
-# Environment Setup
+# Environment Variables & Setup
 # ---------------------------
-load_dotenv()
+# GitHub API key (required for GitHub API calls)
+GITHUB_API_KEY = os.getenv("GITHUB_API_KEY")
+# GROQ API key (if required by ChatGroq)
+GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+# HuggingFace token (if you need it to load private models from HuggingFace)
+HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
+
 CROSS_ENCODER_MODEL = os.getenv("CROSS_ENCODER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
-# Setup a persistent session for GitHub API requests.
+
+# Set up a persistent session for GitHub API requests.
 session = requests.Session()
 session.headers.update({
-    "Authorization": f"token {os.getenv('GITHUB_API_KEY')}",
+    "Authorization": f"token {GITHUB_API_KEY}",
     "Accept": "application/vnd.github.v3+json"
 })
 
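One caveat with the module-level session above: if GITHUB_API_KEY is unset, the f-string yields the literal header value "token None". A hedged sketch of a guarded variant, not part of this commit:

```python
# Hypothetical guarded setup: attach the Authorization header only when the
# key is present, so requests fall back to unauthenticated (rate-limited) mode.
import os
import requests

session = requests.Session()
session.headers.update({"Accept": "application/vnd.github.v3+json"})

github_api_key = os.getenv("GITHUB_API_KEY")
if github_api_key:
    session.headers["Authorization"] = f"token {github_api_key}"
```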
@@ -37,7 +49,9 @@ llm = ChatGroq(
     temperature=0.3,
     max_tokens=512,
     max_retries=3,
+    api_key=GROQ_API_KEY  # Pass GROQ_API_KEY if the ChatGroq library supports it.
 )
+
 prompt = ChatPromptTemplate.from_messages([
     ("system",
      """You are a GitHub search optimization expert.
@@ -115,7 +129,7 @@ def iterative_convert_to_search_tags(query: str, max_iterations: int = 2) -> str
 # ---------------------------
 # GitHub API Helper Functions
 # ---------------------------
-def fetch_readme_content(repo_full_name):
+def fetch_readme_content(repo_full_name: str) -> str:
     readme_url = f"https://api.github.com/repos/{repo_full_name}/readme"
     response = session.get(readme_url)
     if response.status_code == 200:
@@ -126,7 +140,7 @@ def fetch_readme_content(repo_full_name):
         return ""
     return ""
 
-def fetch_markdown_contents(repo_full_name):
+def fetch_markdown_contents(repo_full_name: str) -> str:
     url = f"https://api.github.com/repos/{repo_full_name}/contents"
     response = session.get(url)
     contents = ""
@@ -141,12 +155,12 @@ def fetch_markdown_contents(repo_full_name):
             contents += "\n" + file_resp.text
     return contents
 
-def fetch_all_markdown(repo_full_name):
+def fetch_all_markdown(repo_full_name: str) -> str:
     readme = fetch_readme_content(repo_full_name)
     other_md = fetch_markdown_contents(repo_full_name)
     return readme + "\n" + other_md
 
-def fetch_github_repositories(query, max_results=10):
+def fetch_github_repositories(query: str, max_results: int = 10) -> list:
     url = "https://api.github.com/search/repositories"
     params = {
         "q": query,
@@ -173,12 +187,13 @@ def fetch_github_repositories(query, max_results=10):
 # Dense Retrieval Model Setup
 # ---------------------------
 try:
+    # If using a GPU-enabled model, the HuggingFace token can be used for private models.
     model = SentenceTransformer('all-mpnet-base-v2', device='cuda')
 except Exception as e:
     print("Error initializing GPU for SentenceTransformer; falling back to CPU:", e)
     model = SentenceTransformer('all-mpnet-base-v2', device='cpu')
-
-def robust_min_max_norm(scores):
+
+def robust_min_max_norm(scores: np.ndarray) -> np.ndarray:
     min_val = scores.min()
     max_val = scores.max()
     if max_val - min_val < 1e-10:
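The hunk above truncates robust_min_max_norm at the degenerate-range guard. A self-contained sketch of the pattern; the constant-array fallback is an assumption, since the diff does not show the function's return statements:

```python
# Sketch of robust min-max normalization; the fallback for a (near-)constant
# score vector is assumed, as the hunk cuts the body off.
import numpy as np

def robust_min_max_norm(scores: np.ndarray) -> np.ndarray:
    min_val, max_val = scores.min(), scores.max()
    if max_val - min_val < 1e-10:    # all scores (almost) equal
        return np.ones_like(scores)  # assumed fallback, not from the diff
    return (scores - min_val) / (max_val - min_val)

print(robust_min_max_norm(np.array([1.0, 2.0, 3.0])))  # [0.  0.5 1. ]
```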
@@ -188,17 +203,18 @@ def robust_min_max_norm(scores):
 # ---------------------------
 # Cross-Encoder Re-Ranking Function
 # ---------------------------
-def cross_encoder_rerank_candidates(candidates, query, model_name, top_n=10):
+def cross_encoder_rerank_candidates(candidates: list, query: str, model_name: str, top_n: int = 10) -> list:
     try:
         cross_encoder = CrossEncoder(model_name, device='cuda')
     except Exception as e:
         print("Error initializing CrossEncoder on GPU; falling back to CPU:", e)
-        cross_encoder = CrossEncoder(model_name, device='cpu')
+        cross_encoder = CrossEncoder(model_name, device='cpu')
+
     CHUNK_SIZE = 2000
     MAX_DOC_LENGTH = 5000
     MIN_DOC_LENGTH = 200
 
-    def split_text(text, chunk_size=CHUNK_SIZE):
+    def split_text(text: str, chunk_size: int = CHUNK_SIZE) -> list:
         return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
 
     for candidate in candidates:
@@ -213,13 +229,13 @@ def cross_encoder_rerank_candidates(candidates, query, model_name, top_n=10):
             chunks = split_text(doc)
             pairs = [[query, chunk] for chunk in chunks]
             scores = cross_encoder.predict(pairs)
-            max_score = np.max(scores) if len(scores) > 0 else 0.0
-            avg_score = np.mean(scores) if len(scores) > 0 else 0.0
+            max_score = np.max(scores) if scores else 0.0
+            avg_score = np.mean(scores) if scores else 0.0
             candidate["cross_encoder_score"] = float(0.5 * max_score + 0.5 * avg_score)
         except Exception as e:
             logging.error(f"Error scoring candidate {candidate.get('link', 'unknown')}: {e}")
             candidate["cross_encoder_score"] = 0.0
-
+
     all_scores = [candidate["cross_encoder_score"] for candidate in candidates]
     if all_scores:
         min_score = min(all_scores)
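The scoring above blends the best and the mean chunk score per candidate. Note that CrossEncoder.predict returns a NumPy array, whose truth value is ambiguous for more than one element, so the len() check from the removed lines is the safe emptiness test. A sketch with synthetic scores:

```python
# Sketch of the per-candidate aggregation with synthetic chunk scores.
# NumPy arrays have no unambiguous truth value, hence the len() check.
import numpy as np

scores = np.array([0.2, 0.9, 0.5])  # stand-in for cross_encoder.predict(pairs)
max_score = np.max(scores) if len(scores) > 0 else 0.0
avg_score = np.mean(scores) if len(scores) > 0 else 0.0
cross_encoder_score = float(0.5 * max_score + 0.5 * avg_score)
print(round(cross_encoder_score, 4))  # 0.7167
```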
@@ -227,7 +243,6 @@ def cross_encoder_rerank_candidates(candidates, query, model_name, top_n=10):
     for candidate in candidates:
         candidate["cross_encoder_score"] += -min_score
 
-    # Do not sort solely by cross-encoder score; we want to combine metrics.
     return candidates
 
 # ---------------------------
@@ -318,11 +333,9 @@ def run_repository_ranking(query: str) -> str:
 
     # Step 9: Compute cross-encoder scores for the top candidates.
     top_candidates = ranked_repositories[:100] if len(ranked_repositories) > 100 else ranked_repositories
-    # Update candidates with cross-encoder scores.
     cross_encoder_rerank_candidates(top_candidates, query, model_name=CROSS_ENCODER_MODEL, top_n=len(top_candidates))
 
-    # Now combine both metrics: final_score = w1 * combined_score + w2 * cross_encoder_score.
-    # Adjust weights as needed (here 0.7 for combined, 0.3 for cross-encoder).
+    # Combine both metrics: final_score = w1 * combined_score + w2 * cross_encoder_score.
     w1 = 0.7
     w2 = 0.3
     for candidate in top_candidates:
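To make the Step 9 combination concrete, a small sketch with synthetic candidates; combined_score stands in for the hybrid retrieval score computed earlier in run_repository_ranking:

```python
# Sketch of the final weighted scoring; candidate values are invented.
top_candidates = [
    {"link": "repo-a", "combined_score": 0.8, "cross_encoder_score": 0.4},
    {"link": "repo-b", "combined_score": 0.6, "cross_encoder_score": 0.9},
]
w1, w2 = 0.7, 0.3  # same weights as in the commit
for candidate in top_candidates:
    candidate["final_score"] = (w1 * candidate["combined_score"]
                                + w2 * candidate["cross_encoder_score"])
top_candidates.sort(key=lambda c: c["final_score"], reverse=True)
print([(c["link"], round(c["final_score"], 2)) for c in top_candidates])
# -> [('repo-b', 0.69), ('repo-a', 0.68)]
```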
 