import gradio as gr
import requests
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Load the pretrained CodeReviewer tokenizer and cache the ids of its special tokens.
tokenizer = AutoTokenizer.from_pretrained("microsoft/codereviewer")
tokenizer.special_dict = {
    f"<e{i}>": tokenizer.get_vocab()[f"<e{i}>"] for i in range(99, -1, -1)
}
tokenizer.mask_id = tokenizer.get_vocab()["<mask>"]
tokenizer.bos_id = tokenizer.get_vocab()["<s>"]
tokenizer.pad_id = tokenizer.get_vocab()["<pad>"]
tokenizer.eos_id = tokenizer.get_vocab()["</s>"]
tokenizer.msg_id = tokenizer.get_vocab()["<msg>"]
tokenizer.keep_id = tokenizer.get_vocab()["<keep>"]
tokenizer.add_id = tokenizer.get_vocab()["<add>"]
tokenizer.del_id = tokenizer.get_vocab()["<del>"]
tokenizer.start_id = tokenizer.get_vocab()["<start>"]
tokenizer.end_id = tokenizer.get_vocab()["<end>"]

# Load the pretrained CodeReviewer model and switch it to inference mode.
model = AutoModelForSeq2SeqLM.from_pretrained("microsoft/codereviewer")
model.eval()

MAX_SOURCE_LENGTH = 512
def pad_assert(tokenizer, source_ids):
    # Truncate, wrap with <s> ... </s>, and pad to exactly MAX_SOURCE_LENGTH tokens.
    source_ids = source_ids[:MAX_SOURCE_LENGTH - 2]
    source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
    pad_len = MAX_SOURCE_LENGTH - len(source_ids)
    source_ids += [tokenizer.pad_id] * pad_len
    assert len(source_ids) == MAX_SOURCE_LENGTH, "Not equal length."
    return source_ids
def encode_diff(tokenizer, diff, msg, source):
    # Drop the "@@ ... @@" hunk header and any empty lines.
    difflines = diff.split("\n")[1:]
    difflines = [line for line in difflines if len(line.strip()) > 0]
    # Label each diff line: 0 = deleted, 1 = added, 2 = kept/context.
    map_dic = {"-": 0, "+": 1, " ": 2}
    labels = [map_dic.get(line[0], 2) for line in difflines]
    difflines = [line[1:].strip() for line in difflines]
    # Build the model input: source file, commit message, then the tagged diff lines.
    inputstr = "<s>" + source + "</s>"
    inputstr += "<msg>" + msg
    for label, line in zip(labels, difflines):
        if label == 1:
            inputstr += "<add>" + line
        elif label == 0:
            inputstr += "<del>" + line
        else:
            inputstr += "<keep>" + line
    source_ids = tokenizer.encode(inputstr, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    source_ids = pad_assert(tokenizer, source_ids)
    return source_ids
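
# Illustrative example (hypothetical values): for a hunk such as
#   @@ -1,2 +1,2 @@
#   -x = 1
#   +x = 2
# encode_diff builds the string
#   "<s>" + source + "</s>" + "<msg>" + msg + "<del>x = 1" + "<add>x = 2"
# before tokenizing, truncating and padding it to MAX_SOURCE_LENGTH ids.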
class FileDiffs(object):
    def __init__(self, diff_string):
        diff_array = diff_string.split("\n")
        self.file_name = diff_array[0]
        self.file_path = self.file_name.split("a/", 1)[1].rsplit("b/", 1)[0].strip()
        # Group hunk lines so every entry of self.diffs starts with its "@@ ... @@" header.
        self.diffs = list()
        for line in diff_array[4:]:
            if line.startswith("@@"):
                self.diffs.append(line)
            elif self.diffs:
                self.diffs[-1] += "\n" + line
def review_commit(user, repository, commit):
    # Fetch the commit message and the unified diff of the commit from the GitHub API.
    commit_metadata = requests.get(F"https://api.github.com/repos/{user}/{repository}/commits/{commit}").json()
    msg = commit_metadata["commit"]["message"]
    diff_data = requests.get(F"https://api.github.com/repos/{user}/{repository}/commits/{commit}",
                             headers={"Accept": "application/vnd.github.diff"})
    code_diff = diff_data.text

    # Split the combined diff into per-file diffs.
    files_diffs = list()
    for file in code_diff.split("diff --git"):
        if len(file) > 0:
            fd = FileDiffs(file)
            files_diffs.append(fd)
output = "" | |
for fd in files_diffs: | |
output += F"File:{fd.file_path}\n" | |
source = requests.get(F"https://raw.githubusercontent.com/{user}/{repository}/^{commit}/{fd.file_path}").text | |
for diff in fd.diffs: | |
inputs = torch.tensor([encode_diff(tokenizer, diff, msg, source)], dtype=torch.long).to("cpu") | |
inputs_mask = inputs.ne(tokenizer.pad_id) | |
preds = model.generate(inputs, | |
attention_mask=inputs_mask, | |
use_cache=True, | |
num_beams=5, | |
early_stopping=True, | |
max_length=100, | |
num_return_sequences=2 | |
) | |
preds = list(preds.cpu().numpy()) | |
pred_nls = [tokenizer.decode(id[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False) for id in | |
preds] | |
output += diff + "\n#######\nComment:\n#######\n" + pred_nls[0] + "\n#######\n" | |
return output | |
iface = gr.Interface(fn=review_commit, inputs=["text", "text", "text"], outputs="text")
iface.launch()
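
# Example invocation without the UI (placeholder values; substitute a real user,
# repository name, and commit SHA before running):
# print(review_commit("octocat", "Hello-World", "<commit-sha>"))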