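"""FastAPI server wrapping a MiniCPM speech-to-speech model.

Audio crosses the wire as base64-encoded .npy bytes: clients POST a numpy
waveform and its sample rate to /api/v1/inference and get back the model's
spoken reply (resampled to the request's rate) plus the generated text.
"""
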
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import librosa
import torch
import base64
import io
import logging
import time
import numpy as np
from transformers import AutoModel, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class AudioRequest(BaseModel):
    audio_data: str   # base64-encoded bytes of a .npy numpy audio array
    sample_rate: int  # sample rate of the input audio in Hz


class AudioResponse(BaseModel):
    audio_data: str   # base64-encoded bytes of a .npy numpy audio array
    text: str = ""    # text of the generated response

# Model initialization status
INITIALIZATION_STATUS = {
    "model_loaded": False,
    "error": None
}
# Global Model instance (wraps the checkpoint and its tokenizer);
# set by initialize_model() at startup
model = None


class Model:
    def __init__(self):
        self.model = AutoModel.from_pretrained(
            './models/checkpoint',
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            attn_implementation='sdpa'
        ).eval().cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(
            './models/checkpoint',
            trust_remote_code=True
        )
        # Initialize TTS and convert it to float32 if needed
        self.model.init_tts()
        self.model.tts.float()
        self.model_in_sr = 16000   # sample rate the model expects on input
        self.model_out_sr = 24000  # sample rate the model produces on output
        # Load the reference audio that conditions the assistant's voice
        self.ref_audio, _ = librosa.load('./ref_audios/female_example.wav', sr=self.model_in_sr, mono=True)
        self.sys_prompt = self.model.get_sys_prompt(ref_audio=self.ref_audio, mode='audio_assistant', language='en')
        # Warmup: one dry run so the first real request doesn't pay the cold-start cost
        audio_data = librosa.load('./ref_audios/male_example.wav', sr=self.model_in_sr, mono=True)[0]
        _ = self.inference(audio_data, self.model_in_sr)

    def inference(self, audio_np, input_audio_sr):
        # Resample the input to the rate the model expects
        if input_audio_sr != self.model_in_sr:
            audio_np = librosa.resample(audio_np, orig_sr=input_audio_sr, target_sr=self.model_in_sr)
        user_question = {'role': 'user', 'content': [audio_np]}
        # Single round: system prompt (voice reference) + user audio
        msgs = [self.sys_prompt, user_question]
        res = self.model.chat(
            msgs=msgs,
            tokenizer=self.tokenizer,
            sampling=True,
            max_new_tokens=128,
            use_tts_template=True,
            generate_audio=True,
            temperature=0.3,
        )
        audio = res["audio_wav"].cpu().numpy()
        # Resample the generated audio back to the caller's rate
        if self.model_out_sr != input_audio_sr:
            audio = librosa.resample(audio, orig_sr=self.model_out_sr, target_sr=input_audio_sr)
        return audio, res["text"]


def initialize_model():
    """Initialize the MiniCPM model"""
    global model, INITIALIZATION_STATUS
    try:
        logger.info("Initializing model...")
        model = Model()
        INITIALIZATION_STATUS["model_loaded"] = True
        logger.info("MiniCPM model initialized successfully")
        return True
    except Exception as e:
        INITIALIZATION_STATUS["error"] = str(e)
        logger.error(f"Failed to initialize model: {e}")
        return False


@app.on_event("startup")
async def startup_event():
    """Initialize model on startup"""
    initialize_model()


@app.get("/api/v1/health")
def health_check():
    """Health check endpoint"""
    status = {
        "status": "healthy" if INITIALIZATION_STATUS["model_loaded"] else "initializing",
        "model_loaded": INITIALIZATION_STATUS["model_loaded"],
        "error": INITIALIZATION_STATUS["error"]
    }
    return status
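
# A readiness probe or load balancer can poll the health endpoint before
# routing traffic (example request, assuming the default port below):
#     curl http://localhost:8000/api/v1/health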


@app.post("/api/v1/inference")
async def inference(request: AudioRequest) -> AudioResponse:
    """Run inference with MiniCPM model"""
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )
    try:
        # Decode the base64 payload back into a numpy array (.npy bytes)
        audio_bytes = base64.b64decode(request.audio_data)
        audio_np = np.load(io.BytesIO(audio_bytes)).flatten()
        # Generate the spoken and text response
        start = time.time()
        logger.info(f"Starting inference with audio length {audio_np.shape}")
        audio_response, text_response = model.inference(audio_np, request.sample_rate)
        logger.info(f"Inference took {time.time() - start:.2f} seconds")
        # Serialize the response audio as .npy bytes and base64-encode it
        buffer = io.BytesIO()
        np.save(buffer, audio_response)
        audio_b64 = base64.b64encode(buffer.getvalue()).decode()
        return AudioResponse(
            audio_data=audio_b64,
            text=text_response
        )
    except Exception as e:
        logger.error(f"Inference failed: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=str(e)
        )


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
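
# Example client, a minimal sketch of the wire format this server expects.
# `requests` and `soundfile` are assumptions (not dependencies of this file),
# and "question.wav" is a placeholder path:
#
#     import base64, io
#     import numpy as np
#     import requests
#     import soundfile as sf
#
#     audio, sr = sf.read("question.wav", dtype="float32")
#     if audio.ndim > 1:
#         audio = audio.mean(axis=1)  # the server flattens, so send mono
#     buf = io.BytesIO()
#     np.save(buf, audio)  # serialize as .npy bytes, matching np.load on the server
#     payload = {
#         "audio_data": base64.b64encode(buf.getvalue()).decode(),
#         "sample_rate": sr,
#     }
#     resp = requests.post("http://localhost:8000/api/v1/inference", json=payload).json()
#     reply = np.load(io.BytesIO(base64.b64decode(resp["audio_data"])))
#     sf.write("reply.wav", reply, sr)  # reply comes back at the request's sample rate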