# News Source Classifier API — FastAPI inference service (Hugging Face Space).
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from tensorflow.keras.models import load_model | |
| from tensorflow.keras.preprocessing.text import tokenizer_from_json | |
| from tensorflow.keras.preprocessing.sequence import pad_sequences | |
| import numpy as np | |
| import json | |
| from typing import Union, List | |
app = FastAPI()

# Module-level singletons; populated by load_model_and_tokenizer() at startup
# (or lazily re-loaded by the predict endpoint if startup loading failed).
model = None
tokenizer = None
def load_model_and_tokenizer():
    """Load the Keras classifier and its tokenizer from disk into module globals.

    Reads 'news_classifier.h5' and 'tokenizer.json' from the working directory.
    Re-raises any loading error after printing it, so callers can decide how
    to surface the failure.
    """
    global model, tokenizer
    try:
        model = load_model('news_classifier.h5')
        # tokenizer_from_json expects the raw JSON *string*, not a parsed dict.
        with open('tokenizer.json', 'r') as f:
            tokenizer = tokenizer_from_json(f.read())
    except Exception as e:
        print(f"Error loading model or tokenizer: {str(e)}")
        raise e
# Eagerly load the model and tokenizer at import time so the first
# request does not pay the loading cost.
load_model_and_tokenizer()
class PredictionInput(BaseModel):
    """Request body: a single article text, or a batch of texts."""

    text: Union[str, List[str]]
class PredictionOutput(BaseModel):
    """Single classification result: predicted source label and its confidence."""

    # label is one of "nbc" / "foxnews"; score is the probability of that label.
    label: str
    score: float
@app.get("/")
def read_root():
    """Health/metadata endpoint.

    Returns service name, model type, version, and whether the model and
    tokenizer are currently loaded.
    """
    # NOTE(review): the route decorator was missing — `app` was created but
    # never used to register this handler, so the endpoint was never served.
    return {
        "message": "News Source Classifier API",
        "model_type": "LSTM",
        "version": "1.0",
        "status": "ready" if model and tokenizer else "not_loaded",
    }
@app.post("/predict")
async def predict(input_data: PredictionInput):
    """Classify one or more news texts as 'nbc' or 'foxnews'.

    Accepts a single string or a list of strings in ``input_data.text`` and
    returns a single ``{label, score}`` object or a list of them, mirroring
    the input shape.

    Raises:
        HTTPException(500): if the model cannot be (re)loaded or inference fails.
    """
    # NOTE(review): restored the missing @app.post decorator — without it the
    # endpoint was never registered on `app`.
    # Lazily (re)load if the import-time load failed.
    if not model or not tokenizer:
        try:
            load_model_and_tokenizer()
        except Exception:
            raise HTTPException(status_code=500, detail="Model not loaded")
    try:
        # Normalize to a batch so single strings and lists share one code path.
        texts = input_data.text if isinstance(input_data.text, list) else [input_data.text]
        sequences = tokenizer.texts_to_sequences(texts)
        # maxlen must match the sequence length the model was trained with.
        padded = pad_sequences(sequences, maxlen=41)
        predictions = model.predict(padded, verbose=0)
        results = []
        for pred in predictions:
            # pred[1] is presumably P(nbc) — score reports the probability of
            # whichever label was chosen. TODO confirm output column order.
            label = "nbc" if pred[1] > 0.5 else "foxnews"
            score = float(pred[1] if label == "nbc" else 1 - pred[1])
            results.append({
                "label": label,
                "score": score,
            })
        # Mirror the input shape: a lone dict for a single-string request.
        return results[0] if isinstance(input_data.text, str) else results
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/reload")
async def reload_model():
    """Force a reload of the model and tokenizer from disk.

    Raises:
        HTTPException(500): if reloading fails; detail carries the error text.
    """
    # NOTE(review): restored the missing route decorator so this admin
    # endpoint is actually reachable.
    try:
        load_model_and_tokenizer()
        return {"message": "Model reloaded successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))