|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
|
|
from sentence_transformers import SentenceTransformer, util |
|
|
|
|
|
|
|
# Identifier of the embedding checkpoint handed to SentenceTransformer below.
model_name = "juanwisz/modernbert-python-code-retrieval"

# Use the GPU when PyTorch reports one as available; otherwise stay on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Built once at import time so every search request reuses the same model.
embedding_model = SentenceTransformer(model_name, device=device)
|
|
|
|
|
|
|
|
|
|
|
def retrieve_top_snippets(query, code_input, top_k=3):
    """Rank pasted code snippets by cosine similarity to a query.

    Args:
        query: Natural-language description of the code being searched for.
        code_input: Text holding one or more snippets separated by ``---``.
            ``None`` is treated like an empty string.
        top_k: Maximum number of matches to return. Values below 1 are
            clamped to 1; values above the snippet count are clamped to it.

    Returns:
        A Markdown string with one block per match, best match first. Each
        block shows the similarity score followed by the snippet in a
        ``python`` code fence. If no snippets or no query were supplied, a
        short explanatory message is returned instead.
    """
    # Split on the delimiter and drop fragments that are only whitespace
    # (e.g. the text before a leading "---" or after a trailing one).
    fragments = (code_input or "").split("---")
    snippets = [fragment.strip() for fragment in fragments if fragment.strip()]

    if not snippets:
        return "No code snippets detected (make sure to separate them with ---)."

    # A blank query has nothing to match against; ranking it would only
    # produce arbitrary scores, so report it instead of calling the model.
    if not query or not query.strip():
        return "Please enter a natural language query."

    query_emb = embedding_model.encode(query, convert_to_tensor=True)
    snippets_emb = embedding_model.encode(snippets, convert_to_tensor=True)

    # cos_sim returns a (1, len(snippets)) matrix for a single query; take
    # the only row to get one score per snippet.
    cos_scores = util.cos_sim(query_emb, snippets_emb)[0]

    # topk raises if k exceeds the number of scores, so clamp it.
    k = max(1, min(int(top_k), len(snippets)))
    best = torch.topk(cos_scores, k=k)

    # topk already yields scores in descending order alongside their indices.
    results = [
        f"**Score**: {score:.4f}\n```python\n{snippets[index]}\n```"
        for score, index in zip(best.values.tolist(), best.indices.tolist())
    ]

    return "\n\n".join(results)
|
|
|
|
|
|
|
|
|
|
|
# Stylesheet passed to gr.Blocks: centers the element whose id is
# "container" and limits its width.
css = """
#container {
    margin: 0 auto;
    max-width: 700px;
}
"""

# Build the page. Components and the click listener are declared inside the
# Blocks context so they are registered on `demo`.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        value=(
            "# Code Retrieval using ModernBERT\n"
            "Enter a natural language query and paste multiple Python code snippets, "
            "delimited by `---`. We'll return the top 3 matches."
        )
    )

    with gr.Column(elem_id="container"):
        # Query box sits in its own row above the snippets box.
        with gr.Row():
            query_input = gr.Textbox(
                placeholder="What does your function do? e.g., 'Parse JSON from a string'",
                label="Natural Language Query",
            )

        code_snippets_input = gr.Textbox(
            placeholder="Example:\n---\ndef parse_json(data):\n return json.loads(data)\n---\ndef add_numbers(a, b):\n return a + b\n---",
            lines=10,
            label="Paste Python functions (delimited by ---)",
        )

        search_btn = gr.Button(value="Search", variant="primary")
        results_output = gr.Markdown(label="Top 3 Matches")

    # Clicking "Search" runs the ranking function on both textboxes and
    # renders the returned Markdown in the results area.
    search_btn.click(
        retrieve_top_snippets,
        inputs=[query_input, code_snippets_input],
        outputs=results_output,
    )


# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|