import argparse
import os
from argparse import Namespace
from typing import List

from clearml import Model, Task
from fastapi import FastAPI, Form, Request
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
from pydantic import BaseModel
import torch

from configs import add_args
from models import build_or_load_gen_model

# Maximum token length for model inputs (BOS/EOS and padding included).
MAX_SOURCE_LENGTH = 512

# Load ClearML endpoints & credentials from the environment.
# A missing variable raises KeyError at import time.
CLEARML_API_HOST = os.environ["CLEARML_API_HOST"]
CLEARML_WEB_HOST = os.environ["CLEARML_WEB_HOST"]
CLEARML_FILES_HOST = os.environ["CLEARML_FILES_HOST"]
CLEARML_ACCESS_KEY = os.environ["CLEARML_API_ACCESS_KEY"]
CLEARML_SECRET_KEY = os.environ["CLEARML_API_SECRET_KEY"]

# Apply the credentials to the ClearML SDK.
Task.set_credentials(
    api_host=CLEARML_API_HOST,
    web_host=CLEARML_WEB_HOST,
    files_host=CLEARML_FILES_HOST,
    key=CLEARML_ACCESS_KEY,
    secret=CLEARML_SECRET_KEY,
)


def pad_assert(tokenizer, source_ids: List[int]) -> List[int]:
    """Return ``source_ids`` framed and padded to exactly MAX_SOURCE_LENGTH ids.

    The input is truncated to ``MAX_SOURCE_LENGTH - 2`` ids, wrapped as
    ``[bos_id] + ids + [eos_id]`` and right-padded with ``pad_id``.
    The caller's list is not modified.

    Args:
        tokenizer: Object exposing ``bos_id``, ``eos_id`` and ``pad_id``.
        source_ids: Token ids to frame.

    Returns:
        A new list of length ``MAX_SOURCE_LENGTH``.
    """
    # Leave room for the BOS and EOS ids added below.
    source_ids = source_ids[:MAX_SOURCE_LENGTH - 2]
    source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
    pad_len = MAX_SOURCE_LENGTH - len(source_ids)
    source_ids += [tokenizer.pad_id] * pad_len
    assert len(source_ids) == MAX_SOURCE_LENGTH, "Not equal length."
    return source_ids


def encode_diff(tokenizer, code: str, comment: str) -> List[int]:
    """Encode file content and a review comment into fixed-length model input.

    Layout before framing: ``[bos_id] + code + [eos_id] + [msg_id] + comment``.
    The result is then truncated, wrapped and padded by ``pad_assert``.

    Args:
        tokenizer: Object exposing ``encode`` plus ``bos_id``, ``eos_id``,
            ``msg_id`` and ``pad_id``.
        code: Source file content.
        comment: Review comment text.

    Returns:
        A list of ``MAX_SOURCE_LENGTH`` token ids.
    """
    # [1:-1] drops the first and last id of each encoding
    # (presumably the tokenizer's own BOS/EOS -- confirm for this tokenizer).
    code_ids = tokenizer.encode(
        code, max_length=MAX_SOURCE_LENGTH, truncation=True
    )[1:-1]
    comment_ids = tokenizer.encode(
        comment, max_length=MAX_SOURCE_LENGTH, truncation=True
    )[1:-1]

    # Concatenate: [BOS] + code + [EOS] + [msg_id] + comment
    source_ids = [tokenizer.bos_id] + code_ids + [tokenizer.eos_id]
    source_ids += [tokenizer.msg_id] + comment_ids

    # NOTE(review): pad_assert adds another BOS/EOS pair, so the sequence
    # starts with two BOS ids. Kept as-is because the fine-tuned checkpoint
    # may expect this layout -- confirm against the training preprocessing.
    return pad_assert(tokenizer, source_ids)


# Load base model architecture and tokenizer from HuggingFace.
BASE_MODEL_NAME = "microsoft/codereviewer"
args = Namespace(
    model_name_or_path=BASE_MODEL_NAME,
    load_model_path=None,
)
print(f"Loading base model architecture and tokenizer from: {BASE_MODEL_NAME}")
config, base_model, tokenizer = build_or_load_gen_model(args)
print("Base model architecture and tokenizer loaded.")

# Download the fine-tuned weights via ClearML using the injected credentials.
task = Task.get_task(task_id="2d65a9e213ea49a9b37e1cc89a2b7ff0")
# get_local_copy() yields the directory the artifact was extracted to.
extracted_adapter_dir = task.artifacts["lora-adapter"].get_local_copy()
# Path to the actual weights file inside that directory.
actual_weights_file_path = os.path.join(extracted_adapter_dir, "pytorch_model.bin")
print(f"Fine-tuned adapter weights downloaded and extracted to directory: {extracted_adapter_dir}")
print(f"Loading fine-tuned adapter weights from file: {actual_weights_file_path}")

# Create LoRA configuration matching the fine-tuned checkpoint.
lora_cfg = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q", "wo", "wi", "v", "o", "k"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)

# Wrap base model with PEFT LoRA
peft_model = get_peft_model(base_model, lora_cfg)

# Load adapter-only weights and merge into base.
# NOTE(review): torch.load unpickles the file; the artifact is downloaded at
# startup, so consider weights_only=True if the installed torch supports it.
adapter_state = torch.load(actual_weights_file_path, map_location="cpu")
# strict=False tolerates keys absent from the adapter file, but it also hides
# adapter keys that match nothing in the model, so surface those.
load_result = peft_model.load_state_dict(adapter_state, strict=False)
if load_result.unexpected_keys:
    print(
        f"Warning: {len(load_result.unexpected_keys)} adapter keys did not "
        "match the model and were ignored."
    )
model = peft_model.merge_and_unload()
print("Merged base model with LoRA adapters.")
model.to("cpu")
model.eval()
print("Model ready for inference.")

app = FastAPI()

# In-memory state shared by the routes below (process-local, not persisted).
last_payload = {"comment": "", "files": []}
last_infer_result = {"generated_code": ""}


class FileContent(BaseModel):
    """One file sent with a PR comment."""

    filename: str
    content: str


class PRPayload(BaseModel):
    """Request body for POST /pr-comments."""

    comment: str
    files: List[FileContent]


class InferenceRequest(BaseModel):
    """Request body carrying a comment and files for inference."""

    comment: str
    files: List[FileContent]


@app.get("/")
def root():
    """Return a static message indicating the service is running."""
    return {"message": "FastAPI PR comment service is running"}


@app.post("/pr-comments")
async def receive_pr_comment(payload: PRPayload):
    """Store the posted PR comment payload and echo it back as JSON."""
    global last_payload
    last_payload = payload.dict()
    # Echo the payload; "redirect" is only a hint naming the /show page,
    # no HTTP redirect is performed.
    return JSONResponse(
        content={"status": "received", "payload": last_payload, "redirect": "/show"}
    )


@app.get("/show", response_class=HTMLResponse)
def show_last_comment():
    """Render the last received comment, its files and the last generated code.

    All interpolated values come from request payloads or model output, so
    each one is HTML-escaped before being placed in the page.
    """
    # Local import: the local name `html` is avoided so it cannot shadow
    # the stdlib module.
    from html import escape

    # NOTE(review): the page markup below is a minimal reconstruction --
    # adjust headings/layout as needed.
    files_html = "".join(
        f"<h3>{escape(file['filename'])}</h3>\n<pre>{escape(file['content'])}</pre>\n"
        for file in last_payload["files"]
    )
    page = (
        "<html>\n"
        "<head><title>Last PR Comment</title></head>\n"
        "<body>\n"
        "<h2>Last PR Comment</h2>\n"
        f"<pre>{escape(last_payload['comment'])}</pre>\n"
        "<h2>Files</h2>\n"
        f"{files_html}"
        "<h2>Generated Code</h2>\n"
        f"<pre>{escape(last_infer_result['generated_code'])}</pre>\n"
        "</body>\n"
        "</html>"
    )
    return page


if __name__ == "__main__":
    # Place any CLI/training logic here if needed.
    # This block is NOT executed when running with uvicorn.
    pass