from fastapi import FastAPI
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from typing import List
import os

from clearml import Task
import torch
from models import build_or_load_gen_model
from argparse import Namespace
from peft import get_peft_model, LoraConfig

# maximum token length for inputs
MAX_SOURCE_LENGTH = 512

# Load endpoints & creds
CLEARML_API_HOST     = os.environ["CLEARML_API_HOST"]
CLEARML_WEB_HOST     = os.environ["CLEARML_WEB_HOST"]
CLEARML_FILES_HOST   = os.environ["CLEARML_FILES_HOST"]
CLEARML_ACCESS_KEY   = os.environ["CLEARML_API_ACCESS_KEY"]
CLEARML_SECRET_KEY   = os.environ["CLEARML_API_SECRET_KEY"]

# Apply to SDK
Task.set_credentials(
    api_host=CLEARML_API_HOST,
    web_host=CLEARML_WEB_HOST,
    files_host=CLEARML_FILES_HOST,
    key=CLEARML_ACCESS_KEY,
    secret=CLEARML_SECRET_KEY,
)
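
# Example environment setup (illustrative placeholders only, not real credentials;
# adjust the hosts to your own ClearML deployment):
#
#   export CLEARML_API_HOST=https://api.clear.ml
#   export CLEARML_WEB_HOST=https://app.clear.ml
#   export CLEARML_FILES_HOST=https://files.clear.ml
#   export CLEARML_API_ACCESS_KEY=<your-access-key>
#   export CLEARML_API_SECRET_KEY=<your-secret-key>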


def pad_assert(tokenizer, source_ids):
    # Truncate, wrap with BOS/EOS, and pad up to the fixed MAX_SOURCE_LENGTH
    source_ids = source_ids[:MAX_SOURCE_LENGTH - 2]
    source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
    pad_len = MAX_SOURCE_LENGTH - len(source_ids)
    source_ids += [tokenizer.pad_id] * pad_len
    assert len(source_ids) == MAX_SOURCE_LENGTH, "Not equal length."
    return source_ids

# Encode code content and comment into a single model input sequence
def encode_diff(tokenizer, code, comment):
    # Tokenize code file content
    code_ids = tokenizer.encode(code, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    # Tokenize comment
    comment_ids = tokenizer.encode(comment, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    # Concatenate: [BOS] + code + [EOS] + [msg_id] + comment
    source_ids = [tokenizer.bos_id] + code_ids + [tokenizer.eos_id]
    source_ids += [tokenizer.msg_id] + comment_ids
    # Truncate/pad to the fixed input length (same logic as pad_assert above)
    return pad_assert(tokenizer, source_ids)

# Load base model architecture and tokenizer from HuggingFace
BASE_MODEL_NAME = "microsoft/codereviewer"
args = Namespace(
    model_name_or_path=BASE_MODEL_NAME,
    load_model_path=None,
)
print(f"Loading base model architecture and tokenizer from: {BASE_MODEL_NAME}")
config, base_model, tokenizer = build_or_load_gen_model(args)
print("Base model architecture and tokenizer loaded.")

# Download the fine-tuned weights via ClearML using your injected creds
task = Task.get_task(task_id="9cc33fb4d1d54378b691188c5e230253")
finetuned_weights_path = task.artifacts["lora-pytorch-bin"].get_local_copy()
print(f"Fine-tuned adapter weights downloaded to directory: {os.path.dirname(finetuned_weights_path)}")

# Create LoRA configuration matching the fine-tuned checkpoint
# (q/k/v/o are the T5 attention projections, wi/wo the feed-forward projections)
lora_cfg = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q", "wo", "wi", "v", "o", "k"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM"
)
# Wrap base model with PEFT LoRA
peft_model = get_peft_model(base_model, lora_cfg)
# Load adapter-only weights and merge into base
adapter_state = torch.load(finetuned_weights_path, map_location="cpu")
peft_model.load_state_dict(adapter_state, strict=False)
model = peft_model.merge_and_unload()
print("Merged base model with LoRA adapters.")

model.to("cpu")
model.eval()
print("Model ready for inference.")

app = FastAPI()

last_payload = {"comment": "", "files": []}
last_infer_result = {"generated_code": ""}

class FileContent(BaseModel):
    filename: str
    content: str

class PRPayload(BaseModel):
    comment: str
    files: List[FileContent]

class InferenceRequest(BaseModel):
    comment: str
    files: List[FileContent]


@app.get("/")
def root():
    return {"message": "FastAPI PR comment service is running"}

@app.post("/pr-comments")
async def receive_pr_comment(payload: PRPayload):
    global last_payload
    last_payload = payload.dict()
    # Return the received payload as JSON, plus a pointer to /show for the HTML view
    return JSONResponse(content={"status": "received", "payload": last_payload, "redirect": "/show"})

@app.get("/show", response_class=HTMLResponse)
def show_last_comment():
    html = f"<h2>Received Comment</h2><p>{last_payload['comment']}</p><hr>"
    for file in last_payload["files"]:
        html += f"<h3>{file['filename']}</h3><pre>{file['content']}</pre><hr>"
    return html

@app.post("/infer")
async def infer(request: InferenceRequest):
    global last_infer_result
    print("[DEBUG] Received /infer request with:", request.dict())
    
    # Use only the first file's content as the code context for the model
    code = request.files[0].content if request.files else ""
    source_ids = encode_diff(tokenizer, code, request.comment)
    # print("[DEBUG] source_ids:", source_ids)
    #tokens = [tokenizer.decode([sid], skip_special_tokens=False) for sid in source_ids]
    #print("[DEBUG] tokens:", tokens)
    inputs = torch.tensor([source_ids], dtype=torch.long)
    inputs_mask = inputs.ne(tokenizer.pad_id)
    
    preds = model.generate(
        inputs,
        attention_mask=inputs_mask,
        use_cache=True,
        num_beams=5,
        early_stopping=True,
        max_length=100,
        num_return_sequences=1
    )
    
    pred = preds[0].cpu().numpy()
    # Decode, skipping the first two tokens of the generated sequence
    pred_nl = tokenizer.decode(pred[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    # Replace <add> markers with newlines and strip whitespace
    pred_nl = "\n".join([seg.strip() for seg in pred_nl.split("<add>") if seg.strip()])
    last_infer_result = {"generated_code": pred_nl}
    return last_infer_result

@app.get("/show-infer", response_class=HTMLResponse)
def show_infer_result():
    html = f"<h2>Generated Message</h2><pre>{last_infer_result['generated_code']}</pre>"
    return html

if __name__ == "__main__":
    # Place any CLI/training logic here if needed
    # This block is NOT executed when running with uvicorn
    pass
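
# Usage sketch (assumes this file is saved as main.py; adjust the module path as needed):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Example inference request (payload shape matches InferenceRequest above):
#
#   curl -X POST http://localhost:8000/infer \
#        -H "Content-Type: application/json" \
#        -d '{"comment": "Please handle the empty-list case",
#             "files": [{"filename": "utils.py", "content": "def first(xs):\n    return xs[0]\n"}]}'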