Jiahuita committed on
Commit 50fd5a6
1 Parent(s): cec65b9

error changes

Files changed (2)
  1. app.py +50 -18
  2. requirements.txt +0 -1
app.py CHANGED

@@ -4,16 +4,29 @@ from pydantic import BaseModel
 from tensorflow.keras.models import load_model
 from tensorflow.keras.preprocessing.text import tokenizer_from_json
 from tensorflow.keras.preprocessing.sequence import pad_sequences
+import numpy as np
 import json
 from typing import Union, List

 app = FastAPI()

-# Load model and tokenizer
-model = load_model('news_classifier.h5')
-with open('tokenizer.json', 'r') as f:
-    tokenizer_data = json.load(f)
-tokenizer = tokenizer_from_json(tokenizer_data)
+# Global variables for model and tokenizer
+model = None
+tokenizer = None
+
+def load_model_and_tokenizer():
+    global model, tokenizer
+    try:
+        model = load_model('news_classifier.h5')
+        with open('tokenizer.json', 'r') as f:
+            tokenizer_data = json.load(f)
+        tokenizer = tokenizer_from_json(tokenizer_data)
+    except Exception as e:
+        print(f"Error loading model or tokenizer: {str(e)}")
+        raise e
+
+# Load on startup
+load_model_and_tokenizer()

 class PredictionInput(BaseModel):
     text: Union[str, List[str]]
@@ -22,35 +35,54 @@ class PredictionOutput(BaseModel):
     label: str
     score: float

-@app.post("/predict")
+@app.get("/")
+def read_root():
+    return {
+        "message": "News Source Classifier API",
+        "model_type": "LSTM",
+        "version": "1.0",
+        "status": "ready" if model and tokenizer else "not_loaded"
+    }
+
+@app.post("/predict", response_model=Union[PredictionOutput, List[PredictionOutput]])
 async def predict(input_data: PredictionInput):
+    if not model or not tokenizer:
+        try:
+            load_model_and_tokenizer()
+        except Exception as e:
+            raise HTTPException(status_code=500, detail="Model not loaded")
+
     try:
-        # Convert input to list if it's a single string
+        # Handle both single string and list inputs
         texts = input_data.text if isinstance(input_data.text, list) else [input_data.text]

         # Preprocess
         sequences = tokenizer.texts_to_sequences(texts)
-        padded = pad_sequences(sequences, maxlen=41)  # Use your model's expected input length
+        padded = pad_sequences(sequences, maxlen=41)  # Match your model's input length

-        # Predict
-        predictions = model.predict(padded)
+        # Get predictions
+        predictions = model.predict(padded, verbose=0)

-        # Format results
+        # Process results
         results = []
         for pred in predictions:
-            score = float(pred[1])  # Assuming binary classification
-            label = "foxnews" if score > 0.5 else "nbc"
+            label = "foxnews" if pred[1] > 0.5 else "nbc"
+            score = float(pred[1] if label == "foxnews" else 1 - pred[1])
             results.append({
                 "label": label,
-                "score": score if label == "foxnews" else 1 - score
+                "score": score
             })

         # Return single result if input was single string
         return results[0] if isinstance(input_data.text, str) else results
-
+
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))

-@app.get("/")
-async def root():
-    return {"message": "News Classifier API is running"}
+@app.post("/reload")
+async def reload_model():
+    try:
+        load_model_and_tokenizer()
+        return {"message": "Model reloaded successfully"}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
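For context on the scoring logic in /predict: model.predict returns one row per input text, and the handler reads pred[1] as the probability of the "foxnews" class, so news_classifier.h5 is assumed to end in a 2-unit softmax layer; the class order [nbc, foxnews] is an assumption inferred from the code, not stated in this commit. A minimal sketch of the same post-processing on a dummy prediction array:

import numpy as np

# Dummy stand-in for model.predict(padded, verbose=0):
# one row per input, columns assumed to be [P(nbc), P(foxnews)]
predictions = np.array([[0.2, 0.8],
                        [0.9, 0.1]])

for pred in predictions:
    label = "foxnews" if pred[1] > 0.5 else "nbc"
    # score is the probability of whichever label was chosen
    score = float(pred[1] if label == "foxnews" else 1 - pred[1])
    print(label, round(score, 2))
# foxnews 0.8
# nbc 0.9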
requirements.txt CHANGED

@@ -3,5 +3,4 @@ fastapi>=0.68.0
 uvicorn>=0.15.0
 pydantic>=1.8.2
 numpy>=1.19.2
-scikit-learn>=0.24.2
 python-multipart
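After this commit the API exposes three routes: GET / for a status check, POST /predict for classification, and POST /reload to reload the model. A minimal client sketch, assuming the app is served locally with uvicorn on the default port 8000 (the host, port, and sample headlines are assumptions, not part of this commit):

import requests  # third-party HTTP client, assumed available in the client environment

BASE_URL = "http://localhost:8000"  # assumption: `uvicorn app:app` running locally

# Status check: confirms the model and tokenizer loaded at startup
print(requests.get(f"{BASE_URL}/").json())
# e.g. {"message": "News Source Classifier API", "model_type": "LSTM", "version": "1.0", "status": "ready"}

# Single headline -> single PredictionOutput object
resp = requests.post(f"{BASE_URL}/predict", json={"text": "Example headline"})
print(resp.json())  # {"label": "foxnews" or "nbc", "score": probability of the chosen label}

# List of headlines -> list of PredictionOutput objects
resp = requests.post(f"{BASE_URL}/predict", json={"text": ["Headline one", "Headline two"]})
print(resp.json())

Because /predict now declares response_model=Union[PredictionOutput, List[PredictionOutput]], FastAPI validates that each returned item carries a string label and a float score before sending the response.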