navjotk commited on
Commit
1fd4ca3
·
verified ·
1 Parent(s): 1405d66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -31
app.py CHANGED
@@ -1,45 +1,34 @@
1
  import gradio as gr
2
- from huggingface_hub import hf_hub_download
3
- import joblib
4
  import pandas as pd
5
- import json
6
-
7
- # Load model and config from Hugging Face
8
- repo_id = "abhishek/autotrain-iris-xgboost"
9
-
10
- model_path = hf_hub_download(repo_id=repo_id, filename="model.joblib")
11
- config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
12
 
13
- # Load the model
 
14
  model = joblib.load(model_path)
15
 
16
- # Load the config to get feature names
17
- with open(config_path, "r") as f:
18
- config = json.load(f)
19
-
20
- feature_names = config["features"]
21
-
22
- # Inference function
23
  def predict(sepal_length, sepal_width, petal_length, petal_width):
24
- input_df = pd.DataFrame([[
25
- sepal_length, sepal_width, petal_length, petal_width
26
- ]], columns=feature_names)
27
-
 
 
28
  prediction = model.predict(input_df)[0]
29
- return f"🌸 Predicted species: {prediction}"
30
 
31
  # Gradio interface
32
- demo = gr.Interface(
33
  fn=predict,
34
  inputs=[
35
- gr.Slider(4.0, 8.0, step=0.1, label="Sepal Length"),
36
- gr.Slider(2.0, 5.0, step=0.1, label="Sepal Width"),
37
- gr.Slider(1.0, 7.0, step=0.1, label="Petal Length"),
38
- gr.Slider(0.1, 3.0, step=0.1, label="Petal Width"),
39
  ],
40
- outputs=gr.Textbox(label="Prediction"),
41
- title="🌸 Iris Flower Classifier",
42
- description="Enter flower measurements to predict the species using a model trained with AutoTrain on Hugging Face.",
43
  )
44
 
45
- demo.launch()
 
1
  import gradio as gr
 
 
2
  import pandas as pd
3
+ import joblib
4
+ from huggingface_hub import hf_hub_download
 
 
 
 
 
5
 
6
+ # Download model from Hugging Face
7
+ model_path = hf_hub_download(repo_id="abhishek/autotrain-iris-xgboost", filename="model.joblib")
8
  model = joblib.load(model_path)
9
 
10
+ # Prediction function
 
 
 
 
 
 
11
  def predict(sepal_length, sepal_width, petal_length, petal_width):
12
+ input_df = pd.DataFrame([{
13
+ "feat_SepalLengthCm": sepal_length,
14
+ "feat_SepalWidthCm": sepal_width,
15
+ "feat_PetalLengthCm": petal_length,
16
+ "feat_PetalWidthCm": petal_width
17
+ }])
18
  prediction = model.predict(input_df)[0]
19
+ return prediction
20
 
21
  # Gradio interface
22
+ iface = gr.Interface(
23
  fn=predict,
24
  inputs=[
25
+ gr.Slider(4.0, 8.0, label="Sepal Length (cm)"),
26
+ gr.Slider(2.0, 5.0, label="Sepal Width (cm)"),
27
+ gr.Slider(1.0, 7.0, label="Petal Length (cm)"),
28
+ gr.Slider(0.1, 3.0, label="Petal Width (cm)")
29
  ],
30
+ outputs="text",
31
+ title="Iris Flower Classifier 🌸"
 
32
  )
33
 
34
+ iface.launch()