"""Gradio demo: rank pasted Python snippets against a natural-language query."""

import gradio as gr
import torch
import numpy as np
from sentence_transformers import SentenceTransformer, util

# 1. Load the fine-tuned retrieval model (trained on CodeSearchNet - Python)
#    from the Hugging Face Hub.
model_name = "juanwisz/modernbert-python-code-retrieval"
device = "cuda" if torch.cuda.is_available() else "cpu"

# SentenceTransformer wraps tokenization and embedding behind `.encode`.
embedding_model = SentenceTransformer(model_name, device=device)

# Delimiter the user places between snippets in the text box.
SNIPPET_DELIMITER = "---"
# Maximum number of matches shown to the user (the UI text says "top 3").
TOP_K = 3


def _split_snippets(code_input):
    """Split the raw text box content into non-empty, stripped snippets."""
    # `or ""` tolerates a None payload (e.g. a programmatic caller).
    parts = (code_input or "").split(SNIPPET_DELIMITER)
    return [part.strip() for part in parts if part.strip()]


def _fence_for(snippet):
    """Return a backtick fence that does not occur inside ``snippet``.

    A snippet containing a triple-backtick run (for example in a docstring)
    would otherwise terminate the Markdown code block early.
    """
    fence = "```"
    while fence in snippet:
        fence += "`"
    return fence


# 2. Retrieval entry point used by the Gradio button.
def retrieve_top_snippets(query: str, code_input: str) -> str:
    """Rank code snippets against a query and format the best matches.

    Args:
        query: Natural-language description of the code being searched for.
        code_input: Raw text holding one or more code snippets separated by
            ``---``.

    Returns:
        A Markdown string with up to three snippets, best match first, each
        preceded by its cosine-similarity score. If no snippet or no query
        was supplied, a short explanatory message is returned instead.
    """
    snippets = _split_snippets(code_input)

    # Edge case: nothing to rank.
    if not snippets:
        return "No code snippets detected (make sure to separate them with ---)."

    # Edge case: embedding a blank query would only yield meaningless scores.
    query = (query or "").strip()
    if not query:
        return "Please enter a query."

    # Embed the query and every snippet.
    query_emb = embedding_model.encode(query, convert_to_tensor=True)
    snippets_emb = embedding_model.encode(snippets, convert_to_tensor=True)

    # cos_sim returns a similarity matrix; with a single query, row 0 holds
    # one score per snippet.
    cos_scores = util.cos_sim(query_emb, snippets_emb)[0]

    # topk yields the best scores in descending order together with their
    # positions in `snippets`; convert both to plain Python values.
    top = torch.topk(cos_scores, k=min(TOP_K, len(snippets)))
    scores = top.values.tolist()
    positions = top.indices.tolist()

    results = []
    for score, position in zip(scores, positions):
        snippet_text = snippets[position]
        fence = _fence_for(snippet_text)
        results.append(
            f"**Score**: {score:.4f}\n{fence}python\n{snippet_text}\n{fence}"
        )

    # Separate the matches with a blank line so each renders as its own block.
    return "\n\n".join(results)


#####################
### Gradio Layout ###
#####################

css = """ #container { margin: 0 auto; max-width: 700px; } """

# NOTE(review): the component nesting below was reconstructed from a source
# whose indentation was lost; it assumes the Row wraps only the query box --
# confirm against the intended layout.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Code Retrieval using ModernBERT\n"
                "Enter a natural language query and paste multiple Python code snippets, "
                "delimited by `---`. We'll return the top 3 matches.")
    with gr.Column(elem_id="container"):
        with gr.Row():
            query_input = gr.Textbox(
                label="Natural Language Query",
                placeholder="What does your function do? e.g., 'Parse JSON from a string'"
            )
        code_snippets_input = gr.Textbox(
            label="Paste Python functions (delimited by ---)",
            lines=10,
            placeholder="Example:\n---\ndef parse_json(data):\n return json.loads(data)\n---\ndef add_numbers(a, b):\n return a + b\n---"
        )
        search_btn = gr.Button("Search", variant="primary")
        results_output = gr.Markdown(label="Top 3 Matches")

    # On click, run the retrieval function and render its Markdown output.
    search_btn.click(
        fn=retrieve_top_snippets,
        inputs=[query_input, code_snippets_input],
        outputs=results_output
    )

if __name__ == "__main__":
    demo.launch()