Spaces:
Sleeping
Sleeping
File size: 4,861 Bytes
12ae336 9b744c5 6b0b6fd 9b744c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
"""
Module which builds embeddings for issues and pull requests
The module is designed to be run from the command line and takes the following arguments:
--input_filename: The name of the file containing the issues and pull requests
--model_id: The name of the sentence transformer model to use
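issue_type (positional): The type of issue to embed (either "issue" or "pull")
--n_issues: The maximum number of issues to embed (-1, the default, embeds all of them)
--update: Append to the existing embeddings and index, skipping issues that are already embedded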
The module saves the embeddings to a file called <issue_type>_embeddings.npy and the index to a file called
embedding_index_to_<issue_type>.json
The index provides a mapping from the index of the embedding to the issue or pull request number.
"""
import argparse
import json
import logging
import os
import numpy as np
from sentence_transformers import SentenceTransformer
# Configure the root logger at INFO level so the progress messages logged
# by this module are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the functions and classes in this file.
logger = logging.getLogger(__name__)
def load_model(model_id: str):
    """Build and return the ``SentenceTransformer`` identified by ``model_id``."""
    model = SentenceTransformer(model_id)
    return model
class EmbeddingWriter:
    """Context manager that buffers embeddings and writes them out on exit.

    Entering the context yields a plain list; the caller appends one
    embedding per item to it. On exit, if the list is non-empty, the
    embeddings are stacked into an array and saved with ``np.save`` and the
    index mapping is dumped as JSON. When ``update`` is truthy and the
    embeddings file already exists, the new rows are concatenated after the
    stored ones. An empty buffer leaves both files untouched.

    The exit handler does not inspect the exception arguments and returns
    ``None``: anything collected so far is still written, and an exception
    raised inside the ``with`` body keeps propagating.
    """

    def __init__(self, output_embedding_filename, output_index_filename, update, embedding_to_issue_index) -> None:
        # Buffer handed to the caller by __enter__.
        self.embeddings = []
        self.update = update
        # Mapping written to the index file; the caller fills it in place.
        self.embedding_to_issue_index = embedding_to_issue_index
        self.output_index_filename = output_index_filename
        self.output_embedding_filename = output_embedding_filename

    def __enter__(self):
        return self.embeddings

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing collected: do not create or rewrite any file.
        if not self.embeddings:
            return None

        stacked = np.array(self.embeddings)
        append_to_existing = self.update and os.path.exists(self.output_embedding_filename)
        if append_to_existing:
            previous = np.load(self.output_embedding_filename)
            stacked = np.concatenate([previous, stacked])

        logger.info(f"Saving embeddings to {self.output_embedding_filename}")
        np.save(self.output_embedding_filename, stacked)

        logger.info(f"Saving embedding index to {self.output_index_filename}")
        with open(self.output_index_filename, "w") as index_file:
            json.dump(self.embedding_to_issue_index, index_file, indent=4)
        return None
def embed_issues(
    input_filename: str,
    model_id: str,
    issue_type: str,
    n_issues: int = -1,
    update: bool = False,
):
    """Embed the title and body of issues or pull requests and save the result.

    Embeddings are written to ``<issue_type>_embeddings.npy`` and the mapping
    from embedding row to issue id to ``embedding_index_to_<issue_type>.json``
    (both relative to the current working directory).

    Args:
        input_filename: Path to a JSON file holding a mapping of issue id to
            issue object. Entries without a ``"body"`` key are skipped.
        model_id: Identifier passed to ``load_model``.
        issue_type: ``"pull"`` keeps only entries that have a
            ``"pull_request"`` key; ``"issue"`` keeps only entries without it.
        n_issues: Maximum number of entries to embed in this run. Values
            ``<= 0`` mean "no limit". Skipped entries do not count.
        update: When true and both output files already exist, new
            embeddings are appended and ids already present in the index are
            skipped. If only one of the two files exists the run starts from
            scratch, because resuming would leave rows and index misaligned.
    """
    model = load_model(model_id)

    output_embedding_filename = f"{issue_type}_embeddings.npy"
    output_index_filename = f"embedding_index_to_{issue_type}.json"

    with open(input_filename, "r", encoding="utf-8") as f:
        issues = json.load(f)

    # Resume only when BOTH artefacts are present: the index and the array
    # are positional counterparts, so continuing from just one of them would
    # silently pair embedding rows with the wrong issue ids.
    index_exists = os.path.exists(output_index_filename)
    embeddings_exist = os.path.exists(output_embedding_filename)
    resume = update and index_exists and embeddings_exist
    if update and index_exists != embeddings_exist:
        logger.warning(
            "Only one of %s and %s exists; ignoring it and starting from scratch",
            output_embedding_filename,
            output_index_filename,
        )

    if resume:
        with open(output_index_filename, "r", encoding="utf-8") as f:
            embedding_to_issue_index = json.load(f)
    else:
        embedding_to_issue_index = {}
    # New rows are numbered after the ones already recorded in the index.
    embedding_index = len(embedding_to_issue_index)

    # Built once so the membership test in the loop is O(1) rather than a
    # scan over every recorded id per issue.
    already_embedded = set(embedding_to_issue_index.values())

    max_issues = n_issues if n_issues > 0 else len(issues)
    n_embedded = 0

    with EmbeddingWriter(
        output_embedding_filename=output_embedding_filename,
        output_index_filename=output_index_filename,
        update=resume,
        embedding_to_issue_index=embedding_to_issue_index,
    ) as embeddings:
        for issue_id, issue in issues.items():
            if n_embedded >= max_issues:
                break

            if issue_id in already_embedded:
                logger.info("Skipping issue %s as it is already embedded", issue_id)
                continue

            if "body" not in issue:
                logger.info("Skipping issue %s as it has no body", issue_id)
                continue

            is_pull_request = "pull_request" in issue
            if issue_type == "pull" and not is_pull_request:
                logger.info("Skipping issue %s as it is not a pull request", issue_id)
                continue
            if issue_type == "issue" and is_pull_request:
                logger.info("Skipping issue %s as it is a pull request", issue_id)
                continue

            # A missing or null title/body is embedded as an empty string.
            title = issue.get("title")
            if title is None:
                title = ""
            body = issue["body"]
            if body is None:
                body = ""

            logger.info("Embedding issue %s", issue_id)
            embedding = model.encode(title + "\n" + body)

            # Record the index entry and the row together so they stay aligned.
            embedding_to_issue_index[embedding_index] = issue_id
            embeddings.append(embedding)
            embedding_index += 1
            n_embedded += 1
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and forward them as
    # keyword arguments to embed_issues (argument names match its parameters).
    parser = argparse.ArgumentParser()
    # nargs="?" makes the positional optional. Without it argparse ignores
    # ``default`` and requires the argument on every invocation.
    parser.add_argument("issue_type", nargs="?", choices=["issue", "pull"], default="issue")
    parser.add_argument("--input_filename", type=str, default="issues_dict.json")
    parser.add_argument("--model_id", type=str, default="all-mpnet-base-v2")
    # -1 means "embed everything" (see embed_issues).
    parser.add_argument("--n_issues", type=int, default=-1)
    parser.add_argument("--update", action="store_true")
    args = parser.parse_args()
    embed_issues(**vars(args))
|