Refactor v2
AurelioAguirre committed · Commit cfaa883 · 1 Parent(s): 19b1be5

Files changed:
- .gitignore +1 -0
- Dockerfile +18 -41
- main/__init__.py +0 -0
- main/env_template +0 -55
- main/main.py +0 -61
- main/routes.py +0 -419
- requirements.txt +35 -45
.gitignore CHANGED
@@ -42,3 +42,4 @@ wheels/
 # Logs
 *.log
 logs/
+.cache/
Dockerfile CHANGED
@@ -1,56 +1,33 @@
-#
-FROM
+# Start from NVIDIA CUDA base image
+FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
 
 # Set working directory
-WORKDIR /
+WORKDIR /code
 
-# Install
-RUN apt-get update && \
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    python3.12 \
+    python3-pip \
     git \
-    wget \
-    && apt-get clean \
     && rm -rf /var/lib/apt/lists/*
 
-# Create and set permissions for directories
-RUN mkdir -p /app/.cache/huggingface && \
-    chmod 777 /app/.cache/huggingface && \
-    mkdir -p /app/.git && \
-    chmod 777 /app/.git
-
-# Set environment variables
-ENV TRANSFORMERS_CACHE=/app/.cache/huggingface/hub
-ENV HF_HOME=/app/.cache/huggingface
-ENV GIT_CONFIG_GLOBAL=/app/.git/config
-
 # Copy requirements first to leverage Docker cache
 COPY requirements.txt .
 
 # Install Python dependencies
-RUN
+RUN pip3 install --no-cache-dir -r requirements.txt
 
-#
-
-#
-
-    litgpt download mistralai/Mistral-7B-Instruct-v0.3 \
-    --access_token ${HF_TOKEN} \
-    --checkpoint_dir /app/main/checkpoints || { echo "Download failed with status $?"; exit 1; }
-
-# Copy the rest of the application
-COPY . .
+# Copy the application code
+COPY ./app /code/app
+COPY ./utils /code/utils
 
-# Set environment variables for the application
-ENV LLM_ENGINE_HOST=0.0.0.0
-ENV LLM_ENGINE_PORT=7860
-ENV MODEL_PATH=/app/main/checkpoints/mistralai/Mistral-7B-Instruct-v0.3
+# Set environment variables
+ENV PYTHONPATH=/code
+ENV TRANSFORMERS_CACHE=/code/app/.cache
+ENV CUDA_VISIBLE_DEVICES=0
 
-# Expose port
-EXPOSE
+# Expose the port the app runs on
+EXPOSE 8000
 
 # Command to run the application
-CMD ["
+CMD ["python3", "-m", "app.main"]
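The refactored image now starts the service with `python3 -m app.main` on port 8000, but the new `app/` package itself is not part of this diff. Below is a minimal, hypothetical sketch of what such an entrypoint could look like, assuming the service remains a FastAPI app served by uvicorn; the module layout, the `app` variable, and the placeholder route are assumptions, not code from this repository.

# Hypothetical app/main.py, inferred only from the Dockerfile's
# EXPOSE 8000 and CMD ["python3", "-m", "app.main"]; not part of this commit.
import uvicorn
from fastapi import FastAPI

app = FastAPI(title="LLM Engine Service")

@app.get("/health")
async def health():
    # Placeholder health check; the real app presumably mounts its own routers.
    return {"status": "healthy"}

if __name__ == "__main__":
    # Bind to all interfaces on the port the Dockerfile exposes.
    uvicorn.run(app, host="0.0.0.0", port=8000)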
main/__init__.py DELETED
File without changes
main/env_template DELETED
@@ -1,55 +0,0 @@
-# Service URLs Configuration
-LLM_ENGINE_URL=http://localhost:8001
-RAG_ENGINE_URL=http://localhost:8002
-
-# LLM Engine Server Configuration
-LLM_ENGINE_HOST=0.0.0.0
-LLM_ENGINE_PORT=8001
-
-# RAG Engine Server Configuration (if running locally)
-RAG_ENGINE_HOST=0.0.0.0
-RAG_ENGINE_PORT=8002
-
-# Base Paths Configuration
-BAS_MODEL_PATH=/path/to/your/model
-BAS_RESOURCES=/path/to/resources
-
-# CUDA Memory Management
-PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128,garbage_collection_threshold:0.8,expandable_segments:True
-
-# Other memory-related settings
-CUDA_LAUNCH_BLOCKING=0
-CUDA_VISIBLE_DEVICES=0
-
-# Logging Configuration
-LOG_LEVEL=INFO  # DEBUG, INFO, WARNING, ERROR, CRITICAL
-
-# GPU Configuration (optional)
-# CUDA_VISIBLE_DEVICES=0,1  # Specify which GPUs to use
-
-# Memory Configuration (optional)
-# MAX_GPU_MEMORY=16Gi  # Maximum GPU memory to use
-# MAX_CPU_MEMORY=32Gi  # Maximum CPU memory to use
-
-# Security (if needed)
-# API_KEY=your-api-key-here
-# SSL_CERT_PATH=/path/to/cert
-# SSL_KEY_PATH=/path/to/key
-
-# Development Settings
-# DEBUG=True  # Enable debug mode
-# RELOAD=False  # Enable auto-reload for development
-
-# Model Default Parameters (optional)
-# DEFAULT_MAX_NEW_TOKENS=50
-# DEFAULT_TEMPERATURE=1.0
-# DEFAULT_TOP_K=50
-# DEFAULT_TOP_P=1.0
-
-# Cache Settings (optional)
-# CACHE_DIR=/path/to/cache
-# MAX_CACHE_SIZE=10Gi
-
-# Monitoring (optional)
-# ENABLE_METRICS=True
-# PROMETHEUS_PORT=9090
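The deleted template documents the environment variables the old service read at startup. As a rough illustration of how the same settings could be consumed in Python, here is a short sketch using python-dotenv (which is listed in the new requirements.txt); the `.env` filename and the fallback defaults are assumptions.

# Illustrative loader for a few of the variables documented in env_template.
# The .env filename and the defaults are assumptions; python-dotenv comes from requirements.txt.
import os
from dotenv import load_dotenv

load_dotenv()  # merge key=value pairs from a local .env file into os.environ

llm_host = os.getenv("LLM_ENGINE_HOST", "0.0.0.0")
llm_port = int(os.getenv("LLM_ENGINE_PORT", "8001"))
log_level = os.getenv("LOG_LEVEL", "INFO")

print(f"LLM engine would listen on {llm_host}:{llm_port} (log level {log_level})")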
main/main.py DELETED
@@ -1,61 +0,0 @@
-from fastapi import FastAPI
-from fastapi.middleware.cors import CORSMiddleware
-import logging
-import os
-import uvicorn
-from .routes import router
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-# Initialize FastAPI with simplified configuration
-app = FastAPI(
-    title="LLM Engine Service",
-    docs_url="/docs",
-    redoc_url="/redoc",
-    openapi_url="/openapi.json"
-)
-
-# Add CORS middleware
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-# Include the router from routes.py
-app.include_router(router)
-
-def main():
-    # Load environment variables or configuration here
-    host = os.getenv("LLM_ENGINE_HOST", "0.0.0.0")
-    port = int(os.getenv("LLM_ENGINE_PORT", "7860"))  # Default to 7860 for Spaces
-
-    # Log startup information
-    logger.info(f"Starting LLM Engine service on {host}:{port}")
-    logger.info("Available endpoints:")
-    logger.info("  - /")
-    logger.info("  - /health")
-    logger.info("  - /models")
-    logger.info("  - /initialize")
-    logger.info("  - /generate")
-    logger.info("  - /generate/stream")
-    logger.info("  - /download")
-    logger.info("  - /convert")
-    logger.info("  - /docs")
-    logger.info("  - /redoc")
-    logger.info("  - /openapi.json")
-
-    # Start the server
-    uvicorn.run(
-        app,
-        host=host,
-        port=port,
-        log_level="info"
-    )
-
-if __name__ == "__main__":
-    main()
main/routes.py DELETED
@@ -1,419 +0,0 @@
-
-from fastapi import APIRouter, HTTPException
-from fastapi.responses import StreamingResponse
-from pydantic import BaseModel, Field
-from typing import Optional, Union, AsyncGenerator, List
-import torch
-import logging
-from pathlib import Path
-from litgpt.api import LLM
-from litgpt.scripts.download import download_from_hub
-from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint
-import json
-import asyncio
-
-# Set up logging
-logger = logging.getLogger(__name__)
-
-# Create router instance
-router = APIRouter()
-
-# Global variable to store the LLM instance
-llm_instance = None
-
-class InitializeRequest(BaseModel):
-    """Configuration for model initialization including model path"""
-    mode: str = Field(default="cpu", description="Execution mode ('cpu' or 'gpu')")
-    precision: Optional[str] = Field(None, description="Precision format (e.g., 'bf16-true', 'bf16-mixed')")
-    quantize: Optional[str] = Field(None, description="Quantization format (e.g., 'bnb.nf4')")
-    gpu_count: Union[str, int] = Field(default="auto", description="Number of GPUs to use or 'auto'")
-    model_path: str = Field(..., description="Path to the model relative to checkpoints directory")
-
-class GenerateRequest(BaseModel):
-    """Request parameters for text generation"""
-    prompt: str = Field(..., description="Input text prompt for generation")
-    max_new_tokens: int = Field(default=50, description="Maximum number of tokens to generate")
-    temperature: float = Field(default=1.0, description="Sampling temperature")
-    top_k: Optional[int] = Field(None, description="Top-k sampling parameter")
-    top_p: float = Field(default=1.0, description="Top-p sampling parameter")
-    return_as_token_ids: bool = Field(default=False, description="Whether to return token IDs instead of text")
-    stream: bool = Field(default=False, description="Whether to stream the response")
-
-class StreamGenerateRequest(BaseModel):
-    """Request parameters for streaming text generation"""
-    prompt: str = Field(..., description="Input text prompt for generation")
-    max_new_tokens: int = Field(default=50, description="Maximum number of tokens to generate")
-    temperature: float = Field(default=1.0, description="Sampling temperature")
-    top_k: Optional[int] = Field(None, description="Top-k sampling parameter")
-    top_p: float = Field(default=1.0, description="Top-p sampling parameter")
-
-class DownloadModelRequest(BaseModel):
-    """Request to download a model from HuggingFace"""
-    repo_id: str = Field(
-        ...,
-        description="HuggingFace repository ID (e.g., 'huihui-ai/Llama-3.2-3B-Instruct-abliterated')"
-    )
-    model_name: str = Field(
-        ...,
-        description="Model architecture name (e.g., 'Llama-3.2-3B-Instruct')"
-    )
-    access_token: Optional[str] = Field(
-        None,
-        description="HuggingFace access token for private models"
-    )
-
-class ConvertModelRequest(BaseModel):
-    """Request to convert a downloaded model"""
-    folder_path: str = Field(
-        ...,
-        description="Path relative to checkpoints where model was downloaded"
-    )
-    model_name: str = Field(
-        ...,
-        description="Model architecture name for conversion"
-    )
-
-class ModelResponse(BaseModel):
-    """Model information response"""
-    name: str = Field(..., description="Full model name including organization")
-    path: str = Field(..., description="Relative path in checkpoints directory")
-    downloaded: bool = Field(..., description="Whether the model files are downloaded")
-    converted: bool = Field(..., description="Whether the model is converted to LitGPT format")
-    has_safetensors: bool = Field(..., description="Whether safetensors files are present")
-    files: List[str] = Field(..., description="List of files in model directory")
-
-class ModelsListResponse(BaseModel):
-    """Response for listing models"""
-    models: List[ModelResponse] = Field(..., description="List of available models")
-
-@router.post(
-    "/download",
-    response_model=dict,
-    summary="Download a model from HuggingFace Hub",
-    description="Downloads a model from HuggingFace to the LLM Engine's checkpoints directory",
-    response_description="Download status and location information"
-)
-async def download_model(request: DownloadModelRequest):
-    """
-    Download a model from HuggingFace Hub.
-
-    - Downloads model files to the checkpoints directory
-    - Creates necessary subdirectories
-    - Handles authentication for private models
-
-    Returns:
-        A JSON object containing download status and path information
-    """
-    try:
-        # Get the project root directory and construct paths
-        project_root = Path(__file__).parent.parent
-        checkpoints_dir = project_root / "checkpoints"
-        logger.info(f"Downloading model {request.repo_id} to {checkpoints_dir}")
-
-        download_from_hub(
-            repo_id=request.repo_id,
-            model_name=request.model_name,
-            access_token=request.access_token,
-            checkpoint_dir=checkpoints_dir,
-            tokenizer_only=False
-        )
-
-        return {
-            "status": "success",
-            "message": f"Model downloaded to {checkpoints_dir / request.repo_id}",
-            "path": str(request.repo_id)
-        }
-
-    except Exception as e:
-        logger.error(f"Error downloading model: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Error downloading model: {str(e)}")
-
-@router.post(
-    "/convert",
-    response_model=dict,
-    summary="Convert a model to LitGPT format",
-    description="Converts a downloaded model to the LitGPT format required for inference",
-    response_description="Conversion status and location information"
-)
-async def convert_model(request: ConvertModelRequest):
-    """
-    Convert a downloaded model to LitGPT format.
-
-    - Converts model files to LitGPT's format
-    - Creates lit_model.pth file
-    - Maintains original files
-
-    Returns:
-        A JSON object containing conversion status and path information
-    """
-    try:
-        project_root = Path(__file__).parent.parent
-        checkpoints_dir = project_root / "checkpoints"
-        model_dir = checkpoints_dir / request.folder_path
-
-        if not model_dir.exists():
-            raise HTTPException(
-                status_code=404,
-                detail=f"Model directory not found: {request.folder_path}"
-            )
-
-        logger.info(f"Converting model in {model_dir}")
-        convert_hf_checkpoint(
-            checkpoint_dir=model_dir,
-            model_name=request.model_name
-        )
-
-        return {
-            "status": "success",
-            "message": f"Model converted successfully",
-            "path": str(request.folder_path)
-        }
-
-    except Exception as e:
-        logger.error(f"Error converting model: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Error converting model: {str(e)}")
-
-@router.get(
-    "/models",
-    response_model=ModelsListResponse,
-    summary="List available models",
-    description="Lists all models in the checkpoints directory with their status",
-    response_description="List of models with their details and status"
-)
-async def list_models():
-    """
-    List all models in the checkpoints directory.
-
-    Returns:
-        A JSON object containing:
-        - List of models
-        - Each model's download status
-        - Each model's conversion status
-        - Available files for each model
-    """
-    try:
-        project_root = Path(__file__).parent.parent
-        checkpoints_dir = project_root / "checkpoints"
-        models = []
-
-        if checkpoints_dir.exists():
-            for org_dir in checkpoints_dir.iterdir():
-                if org_dir.is_dir():
-                    for model_dir in org_dir.iterdir():
-                        if model_dir.is_dir():
-                            files = [f.name for f in model_dir.iterdir()]
-                            has_safetensors = any(f.endswith('.safetensors') for f in files)
-                            has_lit_model = 'lit_model.pth' in files
-
-                            model_info = ModelResponse(
-                                name=f"{org_dir.name}/{model_dir.name}",
-                                path=str(model_dir.relative_to(checkpoints_dir)),
-                                downloaded=True,
-                                converted=has_lit_model,
-                                has_safetensors=has_safetensors,
-                                files=files
-                            )
-                            models.append(model_info)
-
-        return ModelsListResponse(models=models)
-
-    except Exception as e:
-        logger.error(f"Error listing models: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Error listing models: {str(e)}")
-
-@router.post("/initialize")
-async def initialize_model(request: InitializeRequest):
-    """
-    Initialize the LLM model with specified configuration.
-    """
-    global llm_instance
-
-    try:
-        # Get the project root directory (where main.py is located)
-        project_root = Path(__file__).parent.parent
-        checkpoints_dir = project_root / "checkpoints"
-        logger.info(f"Checkpoint dir is: {checkpoints_dir}")
-
-        # For LitGPT downloaded models, path includes organization
-        if "/" in request.model_path:
-            # e.g., "mistralai/Mistral-7B-Instruct-v0.3"
-            org, model_name = request.model_path.split("/")
-            model_path = str(checkpoints_dir / org / model_name)
-        else:
-            # Fallback for direct model paths
-            model_path = str(checkpoints_dir / request.model_path)
-
-        logger.info(f"Using model path: {model_path}")
-
-        # Load the model
-        llm_instance = LLM.load(
-            model=model_path,
-            distribute=None if request.precision or request.quantize else "auto"
-        )
-
-        # If manual distribution is needed
-        logger.info("Distributing model")
-        if request.precision or request.quantize:
-            llm_instance.distribute(
-                accelerator="cuda" if request.mode == "gpu" else "cpu",
-                devices=request.gpu_count,
-                precision=request.precision,
-                quantize=request.quantize
-            )
-
-        logger.info(
-            f"Model initialized successfully with config:\n"
-            f"Mode: {request.mode}\n"
-            f"Precision: {request.precision}\n"
-            f"Quantize: {request.quantize}\n"
-            f"GPU Count: {request.gpu_count}\n"
-            f"Model Path: {model_path}\n"
-            f"Current GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, "
-            f"{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved"
-        )
-
-        return {"success": True, "message": "Model initialized successfully"}
-
-    except Exception as e:
-        logger.error(f"Error initializing model: {str(e)}")
-        # Print detailed memory statistics on failure
-        logger.error(f"GPU Memory Stats:\n"
-                     f"Allocated: {torch.cuda.memory_allocated()/1024**3:.2f}GB\n"
-                     f"Reserved: {torch.cuda.memory_reserved()/1024**3:.2f}GB\n"
-                     f"Max Allocated: {torch.cuda.max_memory_allocated()/1024**3:.2f}GB")
-        raise HTTPException(status_code=500, detail=f"Error initializing model: {str(e)}")
-
-@router.post("/generate")
-async def generate(request: GenerateRequest):
-    """
-    Generate text using the initialized model.
-    """
-    global llm_instance
-
-    if llm_instance is None:
-        raise HTTPException(status_code=400, detail="Model not initialized. Call /initialize first.")
-
-    try:
-        if request.stream:
-            raise HTTPException(
-                status_code=400,
-                detail="Streaming is not currently supported through the API"
-            )
-
-        generated_text = llm_instance.generate(
-            prompt=request.prompt,
-            max_new_tokens=request.max_new_tokens,
-            temperature=request.temperature,
-            top_k=request.top_k,
-            top_p=request.top_p,
-            return_as_token_ids=request.return_as_token_ids,
-            stream=False  # Force stream to False for now
-        )
-
-        response = {
-            "generated_text": generated_text if not request.return_as_token_ids else generated_text.tolist(),
-            "metadata": {
-                "prompt": request.prompt,
-                "max_new_tokens": request.max_new_tokens,
-                "temperature": request.temperature,
-                "top_k": request.top_k,
-                "top_p": request.top_p
-            }
-        }
-
-        return response
-
-    except Exception as e:
-        logger.error(f"Error generating text: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
-
-@router.post("/generate/stream")
-async def generate_stream(request: StreamGenerateRequest):
-    """
-    Generate text using the initialized model with streaming response.
-    Returns a StreamingResponse that yields JSON-formatted chunks of text.
-    """
-    global llm_instance
-
-    if llm_instance is None:
-        raise HTTPException(
-            status_code=400,
-            detail="Model not initialized. Call /initialize first."
-        )
-
-    async def event_generator() -> AsyncGenerator[str, None]:
-        try:
-            # Start the generation with streaming enabled
-            for token in llm_instance.generate(
-                prompt=request.prompt,
-                max_new_tokens=request.max_new_tokens,
-                temperature=request.temperature,
-                top_k=request.top_k,
-                top_p=request.top_p,
-                stream=True  # Enable streaming
-            ):
-                # Create a JSON response for each token
-                chunk = {
-                    "token": token,
-                    "metadata": {
-                        "prompt": request.prompt,
-                        "is_finished": False
-                    }
-                }
-                # Format as SSE data
-                yield f"data: {json.dumps(chunk)}\n\n"
-
-                # Small delay to prevent overwhelming the client
-                await asyncio.sleep(0.01)
-
-            # Send final message indicating completion
-            final_chunk = {
-                "token": "",
-                "metadata": {
-                    "prompt": request.prompt,
-                    "is_finished": True
-                }
-            }
-            yield f"data: {json.dumps(final_chunk)}\n\n"
-
-        except Exception as e:
-            logger.error(f"Error in stream generation: {str(e)}")
-            error_chunk = {
-                "error": str(e),
-                "metadata": {
-                    "prompt": request.prompt,
-                    "is_finished": True
-                }
-            }
-            yield f"data: {json.dumps(error_chunk)}\n\n"
-
-    return StreamingResponse(
-        event_generator(),
-        media_type="text/event-stream",
-        headers={
-            'Cache-Control': 'no-cache',
-            'Connection': 'keep-alive',
-        }
-    )
-
-@router.get("/health")
-async def health_check():
-    """
-    Check if the service is running and model is loaded.
-    Returns status information including model details if loaded.
-    """
-    global llm_instance
-
-    status = {
-        "status": "healthy",
-        "model_loaded": llm_instance is not None,
-    }
-
-    if llm_instance is not None:
-        logger.info(f"llm_instance is: {llm_instance}")
-        status["model_info"] = {
-            "model_path": llm_instance.config.name,
-            "device": str(next(llm_instance.model.parameters()).device)
-        }
-
-    return status
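The deleted routes.py defined the old service's HTTP API: /download, /convert, /models, /initialize, /generate, /generate/stream, and /health. For reference, here is a small sketch of how a client could have called the /initialize and /generate endpoints, with payload fields taken from the InitializeRequest and GenerateRequest models above; the base URL and the example model path are assumptions.

# Illustrative client for the removed API; base URL and model path are assumptions.
import requests

BASE_URL = "http://localhost:7860"  # the old Dockerfile defaulted LLM_ENGINE_PORT to 7860

# Load a checkpoint that was previously downloaded and converted.
init = requests.post(f"{BASE_URL}/initialize", json={
    "mode": "gpu",
    "precision": "bf16-true",
    "model_path": "mistralai/Mistral-7B-Instruct-v0.3",
})
init.raise_for_status()

# Generate text with the loaded model.
gen = requests.post(f"{BASE_URL}/generate", json={
    "prompt": "Explain what a Dockerfile does.",
    "max_new_tokens": 50,
    "temperature": 1.0,
})
gen.raise_for_status()
print(gen.json()["generated_text"])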
requirements.txt CHANGED
@@ -1,67 +1,57 @@
-
-aiohttp==3.10.10
-aiosignal==1.3.1
+accelerate==1.1.1
 annotated-types==0.7.0
 anyio==4.6.2.post1
-attrs==24.2.0
 bitsandbytes==0.44.1
 certifi==2024.8.30
 charset-normalizer==3.4.0
 click==8.1.7
-fastapi==0.109.0
+fastapi==0.115.5
 filelock==3.16.1
-frozenlist==1.5.0
 fsspec==2024.10.0
 h11==0.14.0
-huggingface-hub==0.
+huggingface-hub==0.26.3
 idna==3.10
+inquirerpy==0.3.4
 Jinja2==3.1.4
-jsonargparse==4.32.1
-lightning==2.4.0
-lightning-utilities==0.11.8
-litgpt==0.5.3
 MarkupSafe==3.0.2
 mpmath==1.3.0
-multidict==6.1.0
 networkx==3.4.2
-numpy==1.
-nvidia-cublas-cu12==12.
-nvidia-cuda-cupti-cu12==12.
-nvidia-cuda-nvrtc-cu12==12.
-nvidia-cuda-runtime-cu12==12.
+numpy==2.1.3
+nvidia-cublas-cu12==12.4.5.8
+nvidia-cuda-cupti-cu12==12.4.127
+nvidia-cuda-nvrtc-cu12==12.4.127
+nvidia-cuda-runtime-cu12==12.4.127
 nvidia-cudnn-cu12==9.1.0.70
-nvidia-cufft-cu12==11.
-nvidia-curand-cu12==10.3.
-nvidia-cusolver-cu12==11.
-nvidia-cusparse-cu12==12.1.
-nvidia-nccl-cu12==2.
-nvidia-nvjitlink-cu12==12.
-nvidia-nvtx-cu12==12.
-packaging==24.
+nvidia-cufft-cu12==11.2.1.3
+nvidia-curand-cu12==10.3.5.147
+nvidia-cusolver-cu12==11.6.1.9
+nvidia-cusparse-cu12==12.3.1.170
+nvidia-nccl-cu12==2.21.5
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.4.127
+packaging==24.2
+pfzy==0.3.4
+prompt_toolkit==3.0.48
+psutil==6.1.0
+pydantic==2.10.2
+pydantic_core==2.27.1
+python-dotenv==1.0.1
 PyYAML==6.0.2
-regex==2024.
+regex==2024.11.6
 requests==2.32.3
+router==0.1
 safetensors==0.4.5
-setuptools==75.
+setuptools==75.6.0
 sniffio==1.3.1
-starlette==0.
-sympy==1.13.
+starlette==0.41.3
+sympy==1.13.1
 tokenizers==0.20.3
-torch==2.
+torch==2.5.1
+tqdm==4.67.1
+transformers==4.46.3
-triton==3.0.0
-typeshed_client==2.7.0
+triton==3.1.0
 typing_extensions==4.12.2
 urllib3==2.2.3
+utils==1.0.2
+uvicorn==0.32.1
+wcwidth==0.2.13