from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from pydantic import BaseModel
from typing import List
import os
from clearml import Model, Task
import torch
from configs import add_args
from models import build_or_load_gen_model
import argparse
from argparse import Namespace
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig
# maximum token length for inputs
MAX_SOURCE_LENGTH = 512
# Load endpoints & creds
CLEARML_API_HOST = os.environ["CLEARML_API_HOST"]
CLEARML_WEB_HOST = os.environ["CLEARML_WEB_HOST"]
CLEARML_FILES_HOST = os.environ["CLEARML_FILES_HOST"]
CLEARML_ACCESS_KEY = os.environ["CLEARML_API_ACCESS_KEY"]
CLEARML_SECRET_KEY = os.environ["CLEARML_API_SECRET_KEY"]
# Apply to SDK
Task.set_credentials(
    api_host=CLEARML_API_HOST,
    web_host=CLEARML_WEB_HOST,
    files_host=CLEARML_FILES_HOST,
    key=CLEARML_ACCESS_KEY,
    secret=CLEARML_SECRET_KEY,
)
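# Illustrative only: the five environment variables read above are expected to be
# provided by the hosting environment, e.g.
#   export CLEARML_API_HOST=https://api.clear.ml
#   export CLEARML_WEB_HOST=https://app.clear.ml
#   export CLEARML_FILES_HOST=https://files.clear.ml
#   export CLEARML_API_ACCESS_KEY=<your access key>
#   export CLEARML_API_SECRET_KEY=<your secret key>
# The hostnames shown are the hosted clear.ml defaults and are an assumption;
# a self-hosted ClearML server would use its own URLs.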
def pad_assert(tokenizer, source_ids):
    source_ids = source_ids[:MAX_SOURCE_LENGTH - 2]
    source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
    pad_len = MAX_SOURCE_LENGTH - len(source_ids)
    source_ids += [tokenizer.pad_id] * pad_len
    assert len(source_ids) == MAX_SOURCE_LENGTH, "Not equal length."
    return source_ids
# Encode code content and comment into model input
def encode_diff(tokenizer, code, comment):
    # Tokenize code file content
    code_ids = tokenizer.encode(code, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    # Tokenize comment
    comment_ids = tokenizer.encode(comment, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    # Concatenate: [BOS] + code + [EOS] + [msg_id] + comment
    source_ids = [tokenizer.bos_id] + code_ids + [tokenizer.eos_id]
    source_ids += [tokenizer.msg_id] + comment_ids
    # Truncate, re-wrap with BOS/EOS, and right-pad to the fixed length
    return pad_assert(tokenizer, source_ids)
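# Illustrative usage (assumes the tokenizer exposes bos_id / eos_id / msg_id / pad_id,
# as the CodeReviewer tokenizer does); the inputs here are made-up examples:
#   ids = encode_diff(tokenizer, "def add(a, b):\n    return a + b", "Add type hints")
#   assert len(ids) == MAX_SOURCE_LENGTH  # always 512: [BOS] code [EOS] <msg> comment + padding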
# Load base model architecture and tokenizer from HuggingFace
BASE_MODEL_NAME = "microsoft/codereviewer"
args = Namespace(
    model_name_or_path=BASE_MODEL_NAME,
    load_model_path=None,
)
print(f"Loading base model architecture and tokenizer from: {BASE_MODEL_NAME}")
config, base_model, tokenizer = build_or_load_gen_model(args)
print("Base model architecture and tokenizer loaded.")
# Download the fine-tuned weights via ClearML using your injected creds
task = Task.get_task(task_id="9cc33fb4d1d54378b691188c5e230253")
finetuned_weights_path = task.artifacts["lora-pytorch-bin"].get_local_copy()
print(f"Fine-tuned adapter weights downloaded to directory: {os.path.dirname(finetuned_weights_path)}")
# Create LoRA configuration matching the fine-tuned checkpoint
lora_cfg = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q", "wo", "wi", "v", "o", "k"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)
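# Note: "q", "k", "v", "o" are the attention projections and "wi", "wo" the
# feed-forward projections of the T5-style blocks used by microsoft/codereviewer.
# The rank and target modules above are assumed to mirror the fine-tuning run that
# produced the checkpoint; if the names differ, load_state_dict(strict=False) below
# will silently ignore unmatched adapter keys.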
# Wrap base model with PEFT LoRA
peft_model = get_peft_model(base_model, lora_cfg)
# Load adapter-only weights and merge into base
adapter_state = torch.load(finetuned_weights_path, map_location="cpu")
peft_model.load_state_dict(adapter_state, strict=False)
model = peft_model.merge_and_unload()
print("Merged base model with LoRA adapters.")
model.to("cpu")
model.eval()
print("Model ready for inference.")
app = FastAPI()
last_payload = {"comment": "", "files": []}
last_infer_result = {"generated_code": ""}
class FileContent(BaseModel):
    filename: str
    content: str

class PRPayload(BaseModel):
    comment: str
    files: List[FileContent]

class InferenceRequest(BaseModel):
    comment: str
    files: List[FileContent]
@app.get("/")
def root():
return {"message": "FastAPI PR comment service is running"}
@app.post("/pr-comments")
async def receive_pr_comment(payload: PRPayload):
    global last_payload
    last_payload = payload.dict()
    # Return the received payload as JSON; the "redirect" field is a hint for the
    # client to visit /show (no HTTP redirect is issued here)
    return JSONResponse(content={"status": "received", "payload": last_payload, "redirect": "/show"})
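# Illustrative request (hypothetical host/port and payload):
#   curl -X POST http://localhost:8000/pr-comments \
#     -H "Content-Type: application/json" \
#     -d '{"comment": "Please rename this function", "files": [{"filename": "app.py", "content": "def f(): ..."}]}'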
@app.get("/show", response_class=HTMLResponse)
def show_last_comment():
html = f"<h2>Received Comment</h2><p>{last_payload['comment']}</p><hr>"
for file in last_payload["files"]:
html += f"<h3>{file['filename']}</h3><pre>{file['content']}</pre><hr>"
return html
@app.post("/infer")
async def infer(request: InferenceRequest):
    global last_infer_result
    print("[DEBUG] Received /infer request with:", request.dict())
    # Use the first file's content as the code context (single-file inference)
    code = request.files[0].content if request.files else ""
    source_ids = encode_diff(tokenizer, code, request.comment)
    # print("[DEBUG] source_ids:", source_ids)
    # tokens = [tokenizer.decode([sid], skip_special_tokens=False) for sid in source_ids]
    # print("[DEBUG] tokens:", tokens)
    inputs = torch.tensor([source_ids], dtype=torch.long)
    inputs_mask = inputs.ne(tokenizer.pad_id)
    preds = model.generate(
        inputs,
        attention_mask=inputs_mask,
        use_cache=True,
        num_beams=5,
        early_stopping=True,
        max_length=100,
        num_return_sequences=1,
    )
    pred = preds[0].cpu().numpy()
    # Skip the leading special tokens before decoding the generated sequence
    pred_nl = tokenizer.decode(pred[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    # Replace <add> markers with newlines and strip whitespace
    pred_nl = "\n".join([seg.strip() for seg in pred_nl.split("<add>") if seg.strip()])
    last_infer_result = {"generated_code": pred_nl}
    return last_infer_result
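# Illustrative request (hypothetical host/port and payload); the response is JSON of
# the form {"generated_code": "..."}:
#   curl -X POST http://localhost:8000/infer \
#     -H "Content-Type: application/json" \
#     -d '{"comment": "Handle the empty-list case", "files": [{"filename": "utils.py", "content": "def first(xs): return xs[0]"}]}'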
@app.get("/show-infer", response_class=HTMLResponse)
def show_infer_result():
html = f"<h2>Generated Message</h2><pre>{last_infer_result['generated_code']}</pre>"
return html
if __name__ == "__main__":
    # Place any CLI/training logic here if needed
    # This block is NOT executed when running with uvicorn
    pass
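# Typical way to serve this app (the module name "app" and the port are assumptions;
# adjust them to the actual filename and deployment environment):
#   uvicorn app:app --host 0.0.0.0 --port 8000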