from fastapi import FastAPI, Request
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline
import torch

app = FastAPI()

# Serve the static files
app.mount("/static", StaticFiles(directory="static"), name="static")

pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")


def t5(text: str) -> dict[str, str]:
    output = pipe_flan(text)
    return {"output": output[0].get("generated_text", "")}


@app.post("/infer_t5")
async def infer_t5_endpoint(request: Request):
    """Receive input and generate text with FLAN-T5."""
    data = await request.json()
    try:
        input_text = data.get("input")
        # Check for a missing field before the type check, otherwise the
        # assertion below would fire first and the None branch would be dead code
        if input_text is None:
            return JSONResponse({"error": "No input text detected."}, status_code=400)
        # Validate that the input is a string
        assert isinstance(input_text, str), "Input must be a string."
        return t5(input_text)
    except AssertionError as e:
        return JSONResponse({"error": f"Invalid Input Format: {e}"}, status_code=400)


@app.get("/infer_t5")
def get_infer_t5():
    return {"message": "Use POST method to submit input data"}


# Load the Miqu model and tokenizer
MIQU_REPO = "miqudev/miqu-1-70b"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    # Attempt to load the model and tokenizer conventionally. Half precision
    # roughly halves the footprint, but a 70B model still needs on the order
    # of 140 GB of memory.
    model_config = AutoConfig.from_pretrained(MIQU_REPO)
    model = AutoModelForCausalLM.from_pretrained(
        MIQU_REPO, config=model_config, torch_dtype=torch.float16
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(MIQU_REPO)
except Exception as e:
    print("[WARNING]: Failed to load model and tokenizer conventionally.")
    print(f"Exception: {e}")
    # Retry with trust_remote_code=True in case the repository ships custom
    # configuration or modeling code
    model_config = AutoConfig.from_pretrained(MIQU_REPO, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MIQU_REPO, config=model_config, torch_dtype=torch.float16, trust_remote_code=True
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(MIQU_REPO, trust_remote_code=True)


def miku_answer(query: str) -> str:
    query_tokens = tokenizer.encode(query, return_tensors="pt").to(device)
    with torch.no_grad():
        answer = model.generate(
            query_tokens,
            max_new_tokens=128,  # cap the completion length, not prompt + completion
            do_sample=True,      # temperature only takes effect when sampling
            temperature=1.0,
            # LLaMA-style tokenizers often lack a pad token; fall back to EOS
            pad_token_id=tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, skipping the echoed prompt
    # and any special tokens
    return tokenizer.decode(answer[0, query_tokens.shape[-1]:], skip_special_tokens=True)


@app.post("/infer_miku")
async def infer_miku_endpoint(request: Request):
    """Receive input and generate text with Miqu."""
    data = await request.json()
    try:
        input_text = data.get("input")
        if input_text is None:
            return JSONResponse({"error": "No input text detected."}, status_code=400)
        # Validate that the input is a string
        assert isinstance(input_text, str), "Input must be a string."
        return {"output": miku_answer(input_text)}
    except AssertionError as e:
        return JSONResponse({"error": f"Invalid Input Format: {e}"}, status_code=400)


@app.get("/infer_miku")
def get_infer_miku():
    return {"message": "Use POST method to submit input data"}
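
# A minimal root-route sketch: FileResponse is imported above but never used,
# which suggests the mounted static directory was meant to serve a front end.
# This assumes static/index.html exists; adjust the path to match your layout.
@app.get("/")
def root():
    return FileResponse("static/index.html")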
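
# Usage sketch (assumptions: this file is saved as main.py and served with
# `uvicorn main:app`; port 8000 is uvicorn's default). Both POST endpoints
# expect a JSON body of the form {"input": "<text>"}:
#
#   curl -X POST http://localhost:8000/infer_t5 \
#        -H "Content-Type: application/json" \
#        -d '{"input": "Translate English to German: Hello, world!"}'
#
#   curl -X POST http://localhost:8000/infer_miku \
#        -H "Content-Type: application/json" \
#        -d '{"input": "What is the capital of France?"}'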