Spaces: Running on Zero
zamalali committed · Commit 94ed277 · Parent(s): 27298b1

Refactor app.py and main.py for improved readability and functionality; add environment variable loading
Files changed:
- __pycache__/main.cpython-311.pyc +0 -0
- app.py +4 -8
- main.py +33 -20
__pycache__/main.cpython-311.pyc
ADDED
Binary file (21.6 kB).
app.py
CHANGED
@@ -1,12 +1,11 @@
 import spaces
-
-
 import gradio as gr
 import time
-from gradio.themes.utils import sizes
 import threading
 import logging
-from main import run_repository_ranking
+from gradio.themes.utils import sizes
+from main import run_repository_ranking  # Import the repository ranking function
+
 # ---------------------------
 # Global Logging Buffer Setup
 # ---------------------------
@@ -45,8 +44,7 @@ def parse_result_to_html(raw_result: str) -> str:
     Only the top 10 results are displayed.
     """
     entries = raw_result.strip().split("Final Rank:")
-    #
-    entries = entries[1:11]
+    entries = entries[1:11]  # Use only the first 10 entries
     if not entries:
         return "<p>No repositories found for your query.</p>"
     html = """
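For context, the parsing pattern above drops everything before the first "Final Rank:" marker and keeps at most ten entries; a minimal sketch with made-up input:

raw_result = "Header noise Final Rank: 1. repo-a details Final Rank: 2. repo-b details"
entries = raw_result.strip().split("Final Rank:")
# entries[0] is the text before the first marker, so the slice skips it
# and keeps at most the next ten entries.
entries = entries[1:11]
print(len(entries))  # -> 2 for this toy input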
@@ -163,8 +161,6 @@ with gr.Blocks(
         elem_id="header"
     )
 
-
-    # Centered main container for inputs and outputs.
    with gr.Column(elem_id="main-container"):
        research_input = gr.Textbox(
            label="Research Query",
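Together with the from main import run_repository_ranking line added in the first hunk, the column above is what feeds user queries into the ranker. A minimal sketch of that wiring (the button and output component names here are hypothetical; the Space's real layout has more elements and styling):

import gradio as gr
from main import run_repository_ranking

with gr.Blocks() as demo:
    with gr.Column(elem_id="main-container"):
        research_input = gr.Textbox(label="Research Query")
        results_html = gr.HTML()          # hypothetical output component
        search_btn = gr.Button("Search")  # hypothetical control
    # Route the query through the ranking function and render its output.
    search_btn.click(fn=run_repository_ranking, inputs=research_input, outputs=results_html)

demo.launch()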
main.py
CHANGED
@@ -6,7 +6,12 @@ import faiss
 import re
 import logging
 from pathlib import Path
+
+# For local development, load environment variables from a .env file.
+# In HuggingFace Spaces, secrets are automatically available as environment variables.
 from dotenv import load_dotenv
+load_dotenv()
+
 from sentence_transformers import SentenceTransformer, CrossEncoder
 from langchain_groq import ChatGroq
 from langchain_core.prompts import ChatPromptTemplate
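The load_dotenv() call is a no-op when no .env file exists, which is exactly the Spaces case; a minimal sketch of the local-development path (the variable name is one this commit defines):

import os
from dotenv import load_dotenv

# Reads KEY=value pairs from a .env file in the working directory, if one
# exists; on HuggingFace Spaces there is none, and secrets are already
# injected as real environment variables.
load_dotenv()
print("GITHUB_API_KEY set:", os.getenv("GITHUB_API_KEY") is not None)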
@@ -18,14 +23,21 @@ except ImportError:
     BM25Okapi = None
 
 # ---------------------------
-# Environment Setup
+# Environment Variables & Setup
 # ---------------------------
-
+# GitHub API key (required for GitHub API calls)
+GITHUB_API_KEY = os.getenv("GITHUB_API_KEY")
+# GROQ API key (if required by ChatGroq)
+GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+# HuggingFace token (if you need it to load private models from HuggingFace)
+HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
+
 CROSS_ENCODER_MODEL = os.getenv("CROSS_ENCODER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
-
+
+# Set up a persistent session for GitHub API requests.
 session = requests.Session()
 session.headers.update({
-    "Authorization": f"token {
+    "Authorization": f"token {GITHUB_API_KEY}",
     "Accept": "application/vnd.github.v3+json"
 })
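Sending the token on every request matters mainly for rate limits. One way to verify the session is actually authenticated, using GitHub's real /rate_limit endpoint (this check is an illustration, not part of the commit):

import os
import requests

session = requests.Session()
session.headers.update({
    "Authorization": f"token {os.getenv('GITHUB_API_KEY')}",
    "Accept": "application/vnd.github.v3+json",
})

# Authenticated clients get 5,000 core requests/hour instead of 60.
core = session.get("https://api.github.com/rate_limit").json()["resources"]["core"]
print(core["limit"], core["remaining"])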
@@ -37,7 +49,9 @@ llm = ChatGroq(
     temperature=0.3,
     max_tokens=512,
     max_retries=3,
+    api_key=GROQ_API_KEY  # Pass GROQ_API_KEY if the ChatGroq library supports it.
 )
+
 prompt = ChatPromptTemplate.from_messages([
     ("system",
      """You are a GitHub search optimization expert.
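langchain_groq's ChatGroq also reads GROQ_API_KEY from the environment on its own, so the explicit keyword is belt-and-braces. Typical usage of the llm/prompt pair built here, as a sketch (the model id and the human message template are hypothetical; the hunk does not show them):

import os
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate

llm = ChatGroq(
    model="llama3-8b-8192",  # hypothetical; the commit does not show the model id
    temperature=0.3,
    max_tokens=512,
    max_retries=3,
    api_key=os.getenv("GROQ_API_KEY"),
)
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a GitHub search optimization expert."),
    ("human", "{query}"),  # hypothetical placeholder name
])
# Compose prompt and model into a runnable chain and invoke it.
chain = prompt | llm
print(chain.invoke({"query": "graph-based RAG for code search"}).content)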
@@ -115,7 +129,7 @@ def iterative_convert_to_search_tags(query: str, max_iterations: int = 2) -> str
 # ---------------------------
 # GitHub API Helper Functions
 # ---------------------------
-def fetch_readme_content(repo_full_name):
+def fetch_readme_content(repo_full_name: str) -> str:
     readme_url = f"https://api.github.com/repos/{repo_full_name}/readme"
     response = session.get(readme_url)
     if response.status_code == 200:
@@ -126,7 +140,7 @@ def fetch_readme_content(repo_full_name):
         return ""
     return ""
 
-def fetch_markdown_contents(repo_full_name):
+def fetch_markdown_contents(repo_full_name: str) -> str:
     url = f"https://api.github.com/repos/{repo_full_name}/contents"
     response = session.get(url)
     contents = ""
@@ -141,12 +155,12 @@ def fetch_markdown_contents(repo_full_name):
             contents += "\n" + file_resp.text
     return contents
 
-def fetch_all_markdown(repo_full_name):
+def fetch_all_markdown(repo_full_name: str) -> str:
     readme = fetch_readme_content(repo_full_name)
     other_md = fetch_markdown_contents(repo_full_name)
     return readme + "\n" + other_md
 
-def fetch_github_repositories(query, max_results=10):
+def fetch_github_repositories(query: str, max_results: int = 10) -> list:
     url = "https://api.github.com/search/repositories"
     params = {
         "q": query,
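Assuming the helper returns a list of per-repository dicts shaped from the GitHub search response (only the start of its body appears in the hunk), calling it looks like:

from main import fetch_github_repositories

repos = fetch_github_repositories("retrieval augmented generation language:python", max_results=5)
for repo in repos:
    print(repo)  # exact keys depend on how the helper maps the API response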
@@ -173,12 +187,13 @@ def fetch_github_repositories(query, max_results=10):
 # Dense Retrieval Model Setup
 # ---------------------------
 try:
+    # If using a GPU-enabled model, the HuggingFace token can be used for private models.
     model = SentenceTransformer('all-mpnet-base-v2', device='cuda')
 except Exception as e:
     print("Error initializing GPU for SentenceTransformer; falling back to CPU:", e)
     model = SentenceTransformer('all-mpnet-base-v2', device='cpu')
-
-def robust_min_max_norm(scores):
+
+def robust_min_max_norm(scores: np.ndarray) -> np.ndarray:
     min_val = scores.min()
     max_val = scores.max()
     if max_val - min_val < 1e-10:
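The 1e-10 guard exists because min-max normalization divides by (max - min), which blows up when all scores are nearly identical. A standalone sketch of the same pattern; the return value in the degenerate branch is an assumption, since the hunk cuts off before it:

import numpy as np

def robust_min_max_norm(scores: np.ndarray) -> np.ndarray:
    min_val = scores.min()
    max_val = scores.max()
    if max_val - min_val < 1e-10:
        # Assumed behavior: treat identical scores as equally relevant.
        return np.ones_like(scores)
    return (scores - min_val) / (max_val - min_val)

print(robust_min_max_norm(np.array([0.2, 0.5, 0.8])))  # [0.  0.5 1. ]
print(robust_min_max_norm(np.array([0.4, 0.4, 0.4])))  # [1. 1. 1.]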
@@ -188,17 +203,18 @@ def robust_min_max_norm(scores):
 # ---------------------------
 # Cross-Encoder Re-Ranking Function
 # ---------------------------
-def cross_encoder_rerank_candidates(candidates, query, model_name, top_n=10):
+def cross_encoder_rerank_candidates(candidates: list, query: str, model_name: str, top_n: int = 10) -> list:
     try:
         cross_encoder = CrossEncoder(model_name, device='cuda')
     except Exception as e:
         print("Error initializing CrossEncoder on GPU; falling back to CPU:", e)
-        cross_encoder = CrossEncoder(model_name, device='cpu')
+        cross_encoder = CrossEncoder(model_name, device='cpu')
+
     CHUNK_SIZE = 2000
     MAX_DOC_LENGTH = 5000
     MIN_DOC_LENGTH = 200
 
-    def split_text(text, chunk_size=CHUNK_SIZE):
+    def split_text(text: str, chunk_size: int = CHUNK_SIZE) -> list:
         return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
 
     for candidate in candidates:
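split_text is a plain fixed-width slicer, so a document at the 5,000-character MAX_DOC_LENGTH cap yields three chunks:

CHUNK_SIZE = 2000

def split_text(text: str, chunk_size: int = CHUNK_SIZE) -> list:
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

doc = "x" * 5000
print([len(c) for c in split_text(doc)])  # [2000, 2000, 1000]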
@@ -213,13 +229,13 @@ def cross_encoder_rerank_candidates(candidates, query, model_name, top_n=10):
             chunks = split_text(doc)
             pairs = [[query, chunk] for chunk in chunks]
             scores = cross_encoder.predict(pairs)
-            max_score = np.max(scores) if
-            avg_score = np.mean(scores) if
+            max_score = np.max(scores) if scores else 0.0
+            avg_score = np.mean(scores) if scores else 0.0
             candidate["cross_encoder_score"] = float(0.5 * max_score + 0.5 * avg_score)
         except Exception as e:
             logging.error(f"Error scoring candidate {candidate.get('link', 'unknown')}: {e}")
             candidate["cross_encoder_score"] = 0.0
-
+
     all_scores = [candidate["cross_encoder_score"] for candidate in candidates]
     if all_scores:
         min_score = min(all_scores)
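Blending the best chunk score with the mean keeps one highly relevant section from being drowned out while still rewarding documents that are relevant throughout. A worked example with hypothetical per-chunk scores:

import numpy as np

scores = np.array([0.9, 0.1, 0.2])  # hypothetical cross-encoder chunk scores
blended = float(0.5 * np.max(scores) + 0.5 * np.mean(scores))
print(blended)  # 0.5 * 0.9 + 0.5 * 0.4 = 0.65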
@@ -227,7 +243,6 @@
     for candidate in candidates:
         candidate["cross_encoder_score"] += -min_score
 
-    # Do not sort solely by cross-encoder score; we want to combine metrics.
     return candidates
 
 # ---------------------------
@@ -318,11 +333,9 @@ def run_repository_ranking(query: str) -> str:
 
     # Step 9: Compute cross-encoder scores for the top candidates.
     top_candidates = ranked_repositories[:100] if len(ranked_repositories) > 100 else ranked_repositories
-    # Update candidates with cross-encoder scores.
     cross_encoder_rerank_candidates(top_candidates, query, model_name=CROSS_ENCODER_MODEL, top_n=len(top_candidates))
 
-    #
-    # Adjust weights as needed (here 0.7 for combined, 0.3 for cross-encoder).
+    # Combine both metrics: final_score = w1 * combined_score + w2 * cross_encoder_score.
     w1 = 0.7
     w2 = 0.3
     for candidate in top_candidates: