Clement Vachet committed
Commit 9ecca49 · 1 Parent(s): 03a5b00

Improve code based on pylint and black suggestions

app.py CHANGED
@@ -1,14 +1,23 @@
+"""
+Gradio web application
+"""
+
 import os
+import json
 import requests
 import gradio as gr
-from classification.classifier import Classifier
 from dotenv import load_dotenv, find_dotenv
-import json
+
+from classification.classifier import Classifier
+
+AWS_API = None
+
 
 # Initialize API URLs from env file or global settings
 def retrieve_api():
+    """Initialize API URLs from env file or global settings"""
 
-    env_path = find_dotenv('config_api.env')
+    env_path = find_dotenv("config_api.env")
     if env_path:
         load_dotenv(dotenv_path=env_path)
         print("config_api.env file loaded successfully.")
@@ -19,31 +28,35 @@ def retrieve_api():
     global AWS_API
     AWS_API = os.getenv("AWS_API", default="http://localhost:8000")
 
+
 def initialize_classifier():
-    global cls
+    """Initialize ML classifier"""
+
     cls = Classifier()
+    return cls
 
 
 def predict_class_local(sepl, sepw, petl, petw):
+    """ML prediction using direct source code - local"""
+
     data = list(map(float, [sepl, sepw, petl, petw]))
+    cls = initialize_classifier()
     results = cls.load_and_test(data)
     return results
 
 
 def predict_class_aws(sepl, sepw, petl, petw):
+    """ML prediction using AWS API endpoint"""
+
     if AWS_API == "http://localhost:8080":
-        API_endpoint = AWS_API + "/2015-03-31/functions/function/invocations"
+        api_endpoint = AWS_API + "/2015-03-31/functions/function/invocations"
     else:
-        API_endpoint = AWS_API + "/test/classify"
+        api_endpoint = AWS_API + "/test/classify"
 
     data = list(map(float, [sepl, sepw, petl, petw]))
-    json_object = {
-        "features": [
-            data
-        ]
-    }
+    json_object = {"features": [data]}
 
-    response = requests.post(API_endpoint, json=json_object)
+    response = requests.post(api_endpoint, json=json_object, timeout=60)
     if response.status_code == 200:
         # Process the response
         response_json = response.json()
@@ -54,11 +67,14 @@ def predict_class_aws(sepl, sepw, petl, petw):
     return results_dict
 
 
-def predict(sepl, sepw, petl, petw, type):
-    print("type: ", type)
-    if type == "Local":
+def predict(sepl, sepw, petl, petw, execution_type):
+    """ML prediction - local or via API endpoint"""
+
+    print("ML prediction type: ", execution_type)
+    results = None
+    if execution_type == "Local":
         results = predict_class_local(sepl, sepw, petl, petw)
-    elif type == "AWS API":
+    elif execution_type == "AWS API":
         results = predict_class_aws(sepl, sepw, petl, petw)
 
     prediction = results["predictions"][0]
@@ -69,30 +85,41 @@ def predict(sepl, sepw, petl, petw, type):
 
 # Define the Gradio interface
 def user_interface():
+    """Gradio application"""
+
+    description = """
+    Aims: Categorization of different species of iris flowers (Setosa, Versicolor, and Virginica)
+    based on measurements of physical characteristics (sepals and petals).
+
+    Notes: This web application uses two types of machine learning predictions:
+    - local prediction (direct source code)
+    - cloud prediction via an AWS API (i.e. use of ECR, Lambda function and API Gateway)
+    """
+
     with gr.Blocks() as demo:
         gr.Markdown("# IRIS classification task - use of AWS Lambda")
-        gr.Markdown(
-            """
-            Aims: Categorization of different species of iris flowers (Setosa, Versicolor, and Virginica)
-            based on measurements of physical characteristics (sepals and petals).
-
-            Notes: This web application uses two types of predictions:
-            - local prediction (direct source code)
-            - cloud prediction via an AWS API (i.e. use of ECR, Lambda function and API Gateway) to run the machine learning model.
-
-            """
-        )
+        gr.Markdown(description)
 
         with gr.Row():
            with gr.Column():
                with gr.Group():
-                    gr_sepl = gr.Slider(minimum=4.0, maximum=8.0, step=0.1, label="Sepal Length (in cm)")
-                    gr_sepw = gr.Slider(minimum=2.0, maximum=5.0, step=0.1, label="Sepal Width (in cm)")
-                    gr_petl = gr.Slider(minimum=1.0, maximum=7.0, step=0.1, label="Petal Length (in cm)")
-                    gr_petw = gr.Slider(minimum=0.1, maximum=2.8, step=0.1, label="Petal Width (in cm)")
+                    gr_sepl = gr.Slider(
+                        minimum=4.0, maximum=8.0, step=0.1, label="Sepal Length (in cm)"
+                    )
+                    gr_sepw = gr.Slider(
+                        minimum=2.0, maximum=5.0, step=0.1, label="Sepal Width (in cm)"
+                    )
+                    gr_petl = gr.Slider(
+                        minimum=1.0, maximum=7.0, step=0.1, label="Petal Length (in cm)"
+                    )
+                    gr_petw = gr.Slider(
+                        minimum=0.1, maximum=2.8, step=0.1, label="Petal Width (in cm)"
+                    )
            with gr.Column():
                with gr.Row():
-                    gr_type = gr.Radio(["Local", "AWS API"], value="Local", label="Prediction type")
+                    gr_execution_type = gr.Radio(
+                        ["Local", "AWS API"], value="Local", label="Prediction type"
+                    )
                with gr.Row():
                    gr_output = gr.Textbox(label="Prediction output")
 
@@ -100,12 +127,15 @@ def user_interface():
         submit_btn = gr.Button("Submit")
         clear_button = gr.ClearButton()
 
-        submit_btn.click(fn=predict, inputs=[gr_sepl, gr_sepw, gr_petl, gr_petw, gr_type], outputs=[gr_output])
+        submit_btn.click(
+            fn=predict,
+            inputs=[gr_sepl, gr_sepw, gr_petl, gr_petw, gr_execution_type],
+            outputs=[gr_output],
+        )
         clear_button.click(lambda: None, inputs=None, outputs=[gr_output], queue=False)
     demo.queue().launch(debug=True)
 
 
 if __name__ == "__main__":
     retrieve_api()
-    initialize_classifier()
     user_interface()
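
For reference, a minimal sketch of the request/response cycle that predict_class_aws() performs once deployed; the host URL below is an assumption for illustration, while the /test/classify path, the payload layout, and the predictions/probabilities keys come from this commit:

import requests

API_ENDPOINT = "http://localhost:8000/test/classify"  # assumed deployment URL (placeholder)
payload = {"features": [[6.1, 2.8, 4.7, 1.2]]}  # same shape as json_object in predict_class_aws()

response = requests.post(API_ENDPOINT, json=payload, timeout=60)
if response.status_code == 200:
    body = response.json()
    # body is expected to carry "predictions" and "probabilities" (see lambda_function.py)
    print(body)
else:
    print(f"Error: {response.status_code}")
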
classification/classifier.py CHANGED
@@ -1,40 +1,51 @@
-# from sklearn.ensemble import AdaBoostClassifier
+"""
+IRIS Classification - class definition
+"""
+
+import os
+import numpy as np
+import pandas as pd
+import joblib
+
 from sklearn.tree import DecisionTreeClassifier
 from sklearn.datasets import load_iris
 from sklearn.model_selection import train_test_split
-import joblib
-import pandas as pd
-import os
-import numpy as np
+
 
 class Classifier:
+    """Classifier class - ML training and testing"""
+
     def __init__(self):
         pass
 
     def train_and_save(self):
+        """ML training and saving"""
         print("\nIRIS model training...")
         iris = load_iris()
-        cart = DecisionTreeClassifier(max_depth = 3)
+        cart = DecisionTreeClassifier(max_depth=3)
 
-        X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.1, random_state=42)
-        model = cart.fit(X_train, y_train)
+        x_train, x_test, y_train, y_test = train_test_split(
+            iris.data, iris.target, test_size=0.1, random_state=42
+        )
+        model = cart.fit(x_train, y_train)
 
-        print(f"Model score: {cart.score(X_train, y_train):.3f}")
-        print(f"Test Accuracy: {cart.score(X_test, y_test):.3f}")
+        print(f"Model score: {cart.score(x_train, y_train):.3f}")
+        print(f"Test Accuracy: {cart.score(x_test, y_test):.3f}")
 
         current_dir = os.path.dirname(os.path.abspath(__file__))
         parent_dir = os.path.dirname(current_dir)
         test_data_csv_path = os.path.join(parent_dir, "data", "test_data.csv")
 
-        pd.concat([pd.DataFrame(X_test), pd.DataFrame(y_test, columns=['4'])], axis=1).to_csv(test_data_csv_path,
-                                                                                              index=False)
+        pd.concat([pd.DataFrame(x_test), pd.DataFrame(y_test, columns=["4"])], axis=1).to_csv(
+            test_data_csv_path, index=False
+        )
 
         model_path = os.path.join(parent_dir, "models", "model.pkl")
         joblib.dump(model, model_path)
         print(f"Model saved to {model_path}")
 
-
     def load_and_test(self, data):
+        "ML loading and testing"
        print("\nIRIS model prediction...")
 
        current_dir = os.path.dirname(os.path.abspath(__file__))
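
As a usage note, a minimal sketch of how the Classifier class is driven elsewhere in this repo (it mirrors inference_direct.py; the sample feature values are arbitrary):

from classification.classifier import Classifier

cls = Classifier()
cls.train_and_save()  # fits the decision tree and writes models/model.pkl
results = cls.load_and_test([[6.5, 3.0, 5.8, 2.2]])  # features passed as a list of rows
print(results["predictions"], results["probabilities"])
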
inference_api.py CHANGED
@@ -1,8 +1,11 @@
-import requests
-import io
+"""
+IRIS classification - command line inference via API
+"""
+
+import sys
 import json
 import argparse
-import sys
+import requests
 
 
 # Default examples
@@ -13,11 +16,13 @@ def arg_parser():
     """Parse arguments"""
 
     # Create an ArgumentParser object
-    parser = argparse.ArgumentParser(description='Object detection inference via API call')
+    parser = argparse.ArgumentParser(description="IRIS classification inference via API call")
     # Add arguments
-    parser.add_argument('-u', '--url', type=str, help='URL to the server (with endpoint location)', required=True)
-    parser.add_argument('-d', '--data', type=str, help='Input data', required=True)
-    parser.add_argument('-v', '--verbose', action='store_true', help='Increase output verbosity')
+    parser.add_argument(
+        "-u", "--url", type=str, help="URL to the server (with endpoint location)", required=True
+    )
+    parser.add_argument("-d", "--data", type=str, help="Input data", required=True)
+    parser.add_argument("-v", "--verbose", action="store_true", help="Increase output verbosity")
     return parser
 
 
@@ -27,19 +32,19 @@ def main(args=None):
     args = arg_parser().parse_args(args)
     # Use the arguments
     if args.verbose:
-        print(f'Input data: {args.data}')
-        print(f'Input data type: {type(args.data)}')
+        print(f"Input data: {args.data}")
+        print(f"Input data type: {type(args.data)}")
 
     # Send request to API
-    response = requests.post(args.url, json=json.loads(args.data))
+    response = requests.post(args.url, json=json.loads(args.data), timeout=60)
 
     if response.status_code == 200:
         # Process the response
         processed_data = json.loads(response.content)
-        print('processed_data', processed_data)
+        print("processed_data", processed_data)
     else:
         print(f"Error: {response.status_code}")
 
 
 if __name__ == "__main__":
-    sys.exit(main(sys.argv[1:]))
+    sys.exit(main(sys.argv[1:]))
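
For reference, main() can also be driven programmatically with the same flags the parser defines; the URL below is a placeholder, not a real deployment:

from inference_api import main

main([
    "-u", "http://localhost:8000/test/classify",  # assumed endpoint (placeholder)
    "-d", '{"features": [[6.1, 2.8, 4.7, 1.2]]}',  # JSON string, parsed with json.loads
    "-v",
])
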
inference_direct.py CHANGED
@@ -1,5 +1,9 @@
-from classification.classifier import Classifier
+"""
+Direct inference with hard-coded data
+"""
+
 import json
+from classification.classifier import Classifier
 
 
 if __name__ == "__main__":
@@ -9,19 +13,16 @@ if __name__ == "__main__":
     cls.train_and_save()
 
     # Testing
-    data = {
-        "features": [[6.5, 3.0, 5.8, 2.2],[6.1, 2.8, 4.7, 1.2]]
-    }
+    data = {"features": [[6.5, 3.0, 5.8, 2.2], [6.1, 2.8, 4.7, 1.2]]}
     features = data["features"]
     results = cls.load_and_test(features)
     print("results:", results)
 
     # Response similar to REST API call
     response = {
-        'statusCode': 200,
-        'body': json.dumps({
-            'predictions': results["predictions"],
-            'probabilities': results["probabilities"]
-        })
+        "statusCode": 200,
+        "body": json.dumps(
+            {"predictions": results["predictions"], "probabilities": results["probabilities"]}
+        ),
     }
     print("Example REST API response: ", response)
lambda_function.py CHANGED
@@ -1,30 +1,39 @@
-from classification.classifier import Classifier
+"""
+AWS Lambda function
+"""
+
 import json
+from classification.classifier import Classifier
 
 
 cls = Classifier()
 
+
 # Lambda handler (proxy integration option unchecked on AWS API Gateway)
 def lambda_handler(event, context):
+    """
+    Lambda handler (proxy integration option unchecked on AWS API Gateway)
+
+    Args:
+        event (dict): The event that triggered the Lambda function.
+        context (LambdaContext): Information about the execution environment.
+
+    Returns:
+        dict: The response to be returned from the Lambda function.
+    """
 
     try:
-        features = event.get('features', {})
+        features = event.get("features", {})
         if not features:
             raise ValueError("'features' key missing")
 
         response = cls.load_and_test(features)
         return {
-            'statusCode': 200,
-            'headers': {
-                'Content-Type': 'application/json'
-            },
-            'body': json.dumps({
-                'predictions': response["predictions"],
-                'probabilities': response["probabilities"]
-            })
+            "statusCode": 200,
+            "headers": {"Content-Type": "application/json"},
+            "body": json.dumps(
+                {"predictions": response["predictions"], "probabilities": response["probabilities"]}
+            ),
         }
     except Exception as e:
-        return {
-            'statusCode': 500,
-            'body': json.dumps({'error': str(e)})
-        }
+        return {"statusCode": 500, "body": json.dumps({"error": str(e)})}
 
 
 
models/model.pkl CHANGED
Binary files a/models/model.pkl and b/models/model.pkl differ