File size: 4,861 Bytes
12ae336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b744c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b0b6fd
9b744c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
Module which builds embeddings for issues and pull requests

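The module is designed to be run from the command line and takes the following arguments:

        --input_filename: The name of the JSON file containing the issues and pull requests (default: issues_dict.json)
        --model_id: The name of the sentence transformer model to use (default: all-mpnet-base-v2)
        issue_type: The type of item to embed, given as a positional argument (either "issue" or "pull")
        --n_issues: The maximum number of issues to embed (default: -1, meaning no limit)
        --update: If set, append to the existing embeddings and index instead of overwriting them, skipping issues that are already embedded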

The module saves the embeddings to a file called <issue_type>_embeddings.npy and the index to a file called
embedding_index_to_<issue_type>.json

The index provides a mapping from the index of the embedding to the issue or pull request number.

"""

import argparse
import json
import logging
import os

import numpy as np
from sentence_transformers import SentenceTransformer

# Configure the root logger so the per-issue progress messages below are shown.
logging.basicConfig(level="INFO")

# Module-level logger used by the embedding helpers.
logger = logging.getLogger(__name__)


def load_model(model_id: str):
    """Instantiate and return the ``SentenceTransformer`` identified by ``model_id``."""
    transformer = SentenceTransformer(model_id)
    return transformer


class EmbeddingWriter:
    """Collect embeddings inside a ``with`` block and save them when it exits.

    ``__enter__`` returns a plain list to which the caller appends embedding
    vectors. On exit the vectors are stacked with ``np.array`` and written with
    ``np.save`` to ``output_embedding_filename``. ``embedding_to_issue_index``
    is then dumped as indented JSON to ``output_index_filename``.

    ``__exit__`` does not look at ``exc_type``, so the save also runs when the
    block is left through an exception. The exception then still propagates.
    If no embeddings were appended, nothing is written.

    Args:
        output_embedding_filename: Path of the ``.npy`` file to write.
        output_index_filename: Path of the JSON index file to write.
        update: When truthy and the embedding file already exists, the new
            embeddings are appended after the stored ones instead of
            replacing them.
        embedding_to_issue_index: Mapping from embedding row to issue id,
            expected to be filled in by the caller during the ``with`` block.
            It is saved as-is.
    """

    def __init__(self, output_embedding_filename, output_index_filename, update, embedding_to_issue_index) -> None:
        # Destination paths.
        self.output_embedding_filename = output_embedding_filename
        self.output_index_filename = output_index_filename
        # Persistence mode and the (shared, caller-mutated) index mapping.
        self.update = update
        self.embedding_to_issue_index = embedding_to_issue_index
        # Buffer the caller appends to between __enter__ and __exit__.
        self.embeddings = []

    def __enter__(self):
        buffer = self.embeddings
        return buffer

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.embeddings:
            # Nothing new was produced: leave any existing files untouched.
            return None

        fresh = np.array(self.embeddings)
        extend_existing = self.update and os.path.exists(self.output_embedding_filename)
        if extend_existing:
            previous = np.load(self.output_embedding_filename)
            to_save = np.concatenate([previous, fresh])
        else:
            to_save = fresh

        logger.info(f"Saving embeddings to {self.output_embedding_filename}")
        np.save(self.output_embedding_filename, to_save)

        logger.info(f"Saving embedding index to {self.output_index_filename}")
        with open(self.output_index_filename, "w") as index_file:
            json.dump(self.embedding_to_issue_index, index_file, indent=4)

        # A falsy return value lets any in-flight exception propagate.
        return None


def embed_issues(
    input_filename: str,
    model_id: str,
    issue_type: str,
    n_issues: int = -1,
    update: bool = False
):
    """Embed the title and body of issues or pull requests and save the result.

    Each selected record is encoded as ``title + "\\n" + body`` with the model
    returned by ``load_model``. The vectors and the row -> issue-id mapping are
    saved by ``EmbeddingWriter`` when the ``with`` block exits:

    - embeddings go to ``<issue_type>_embeddings.npy``;
    - the mapping goes to ``embedding_index_to_<issue_type>.json``.

    Args:
        input_filename: JSON file holding a mapping of issue id -> issue record.
        model_id: Sentence-transformer model name passed to ``load_model``.
        issue_type: ``"pull"`` keeps only records that have a
            ``"pull_request"`` key. ``"issue"`` keeps only records without one.
        n_issues: Maximum number of records to embed in this run. A value
            <= 0 means no limit.
        update: If True, load the existing index (when present), skip ids
            already in it, and have ``EmbeddingWriter`` append to the
            existing embedding file instead of overwriting it.
    """
    model = load_model(model_id)

    output_embedding_filename = f"{issue_type}_embeddings.npy"
    output_index_filename = f"embedding_index_to_{issue_type}.json"

    # JSON text is UTF-8 by spec; don't depend on the platform locale encoding.
    with open(input_filename, "r", encoding="utf-8") as f:
        issues = json.load(f)

    if update and os.path.exists(output_index_filename):
        with open(output_index_filename, "r", encoding="utf-8") as f:
            embedding_to_issue_index = json.load(f)
    else:
        embedding_to_issue_index = {}
    # New rows are numbered after the existing ones so they line up with the
    # rows EmbeddingWriter appends to the saved array.
    embedding_index = len(embedding_to_issue_index)

    # Ids already present in the saved index, as a set for O(1) lookups.
    # Issue ids are unique dict keys, so ids added during this run never
    # need to be checked again.
    already_embedded = set(embedding_to_issue_index.values()) if update else set()

    max_issues = n_issues if n_issues > 0 else len(issues)
    n_embedded = 0

    with EmbeddingWriter(
        output_embedding_filename=output_embedding_filename,
        output_index_filename=output_index_filename,
        update=update,
        embedding_to_issue_index=embedding_to_issue_index,
    ) as embeddings:
        for issue_id, issue in issues.items():
            if n_embedded >= max_issues:
                break

            if issue_id in already_embedded:
                logger.info(f"Skipping issue {issue_id} as it is already embedded")
                continue

            if "body" not in issue:
                logger.info(f"Skipping issue {issue_id} as it has no body")
                continue

            if issue_type == "pull" and "pull_request" not in issue:
                logger.info(f"Skipping issue {issue_id} as it is not a pull request")
                continue

            if issue_type == "issue" and "pull_request" in issue:
                logger.info(f"Skipping issue {issue_id} as it is a pull request")
                continue

            # Title/body may be null in the data; a missing title is tolerated too.
            title = issue.get("title") or ""
            body = issue["body"] or ""

            logger.info(f"Embedding issue {issue_id}")
            embedding = model.encode(title + "\n" + body)
            embedding_to_issue_index[embedding_index] = issue_id
            embeddings.append(embedding)
            embedding_index += 1
            n_embedded += 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Build sentence-transformer embeddings for issues or pull requests."
    )
    # nargs="?" makes the positional optional. Without it argparse ignores
    # `default` for a positional argument and always requires a value.
    parser.add_argument(
        "issue_type",
        nargs="?",
        choices=["issue", "pull"],
        default="issue",
        help="Which records to embed (default: issue).",
    )
    parser.add_argument("--input_filename", type=str, default="issues_dict.json",
                        help="JSON file mapping issue id to issue record.")
    parser.add_argument("--model_id", type=str, default="all-mpnet-base-v2",
                        help="Sentence-transformer model to load.")
    parser.add_argument("--n_issues", type=int, default=-1,
                        help="Maximum number of records to embed; <= 0 means all.")
    parser.add_argument("--update", action="store_true",
                        help="Append to existing embeddings instead of overwriting them.")
    args = parser.parse_args()
    # Namespace attribute names match embed_issues' parameter names.
    embed_issues(**vars(args))