# Hugging Face Space app.py by juanwisz: verified commit a19817b ("Update app.py")
import gradio as gr
import torch
import numpy as np
from sentence_transformers import SentenceTransformer, util
# 1. Load the fine-tuned retrieval model (trained on the Python split of
#    CodeSearchNet). The checkpoint is fetched from the Hugging Face Hub by name.
model_name = "juanwisz/modernbert-python-code-retrieval"

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# SentenceTransformer takes care of both tokenization and embedding.
embedding_model = SentenceTransformer(model_name, device=device)
# 2. Retrieval logic:
# - split the pasted text into snippets on delimiter lines consisting of "---"
# - embed the user's query and every snippet
# - return the top-k snippets ranked by cosine similarity, rendered as Markdown


def _split_snippets(code_input):
    """Split pasted text into non-empty, whitespace-stripped snippets.

    A delimiter is a line whose only non-whitespace content is ``---``.
    Matching whole lines (instead of any ``---`` substring) keeps a snippet
    intact when the code itself contains dashes, e.g. the ``----------``
    section underlines of NumPy-style docstrings or ``---`` inside strings.
    """
    chunks = []
    current = []
    for line in code_input.splitlines():
        if line.strip() == "---":
            chunks.append("\n".join(current))
            current = []
        else:
            current.append(line)
    chunks.append("\n".join(current))
    return [chunk.strip() for chunk in chunks if chunk.strip()]


def retrieve_top_snippets(query, code_input, top_k=3):
    """Rank pasted code snippets by semantic similarity to a query.

    Args:
        query: Natural-language description of the wanted code.
        code_input: Raw text holding one or more snippets separated by lines
            consisting of ``---``. ``None`` is treated as empty text.
        top_k: Maximum number of snippets to return (default 3, matching the
            UI text). Values below 1 are treated as 1.

    Returns:
        A Markdown string with up to ``top_k`` snippets, most similar first,
        each preceded by its cosine-similarity score. If no snippets were
        found, or the query is blank, a short explanatory message instead.
    """
    snippets = _split_snippets(code_input or "")
    if not snippets:
        return "No code snippets detected (make sure to separate them with ---)."
    if not query or not query.strip():
        # Ranking against an empty query would produce meaningless scores.
        return "Please enter a natural language query."

    query_emb = embedding_model.encode(query, convert_to_tensor=True)
    snippets_emb = embedding_model.encode(snippets, convert_to_tensor=True)

    # cos_sim yields a (num_queries x num_snippets) matrix; with a single
    # query, row 0 holds its score against every snippet.
    cos_scores = util.cos_sim(query_emb, snippets_emb)[0]

    # topk returns its values sorted in descending order, best match first.
    k = min(max(int(top_k), 1), len(snippets))
    top = torch.topk(cos_scores, k=k)

    results = [
        f"**Score**: {score:.4f}\n```python\n{snippets[idx]}\n```"
        for score, idx in zip(top.values.tolist(), top.indices.tolist())
    ]
    return "\n\n".join(results)
#####################
### Gradio Layout ###
#####################

# Keep the main content in a narrow, horizontally centred column.
css = """
#container {
margin: 0 auto;
max-width: 700px;
}
"""

# Static UI text, kept out of the layout code for readability.
_INTRO_MD = (
    "# Code Retrieval using ModernBERT\n"
    "Enter a natural language query and paste multiple Python code snippets, "
    "delimited by `---`. We'll return the top 3 matches."
)
_CODE_PLACEHOLDER = "Example:\n---\ndef parse_json(data):\n return json.loads(data)\n---\ndef add_numbers(a, b):\n return a + b\n---"

with gr.Blocks(css=css) as demo:
    gr.Markdown(_INTRO_MD)

    with gr.Column(elem_id="container"):
        with gr.Row():
            query_input = gr.Textbox(
                label="Natural Language Query",
                placeholder="What does your function do? e.g., 'Parse JSON from a string'",
            )
        code_snippets_input = gr.Textbox(
            label="Paste Python functions (delimited by ---)",
            lines=10,
            placeholder=_CODE_PLACEHOLDER,
        )
        search_btn = gr.Button("Search", variant="primary")
        results_output = gr.Markdown(label="Top 3 Matches")

    # Pressing "Search" passes (query, pasted code) to the retrieval function
    # and shows the Markdown it returns in the results pane.
    search_btn.click(
        retrieve_top_snippets,
        inputs=[query_input, code_snippets_input],
        outputs=results_output,
    )

if __name__ == "__main__":
    demo.launch()