navjotk commited on
Commit
834cdd5
·
verified ·
1 Parent(s): 88f930f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -17
app.py CHANGED
@@ -1,34 +1,32 @@
1
  import gradio as gr
2
- import pandas as pd
3
  import joblib
 
4
  from huggingface_hub import hf_hub_download
5
 
6
- # Download model from Hugging Face
7
  model_path = hf_hub_download(repo_id="abhishek/autotrain-iris-xgboost", filename="model.joblib")
8
  model = joblib.load(model_path)
9
 
10
- # Prediction function
 
 
11
  def predict(sepal_length, sepal_width, petal_length, petal_width):
12
- input_df = pd.DataFrame([{
13
- "feat_SepalLengthCm": sepal_length,
14
- "feat_SepalWidthCm": sepal_width,
15
- "feat_PetalLengthCm": petal_length,
16
- "feat_PetalWidthCm": petal_width
17
- }])
18
- prediction = model.predict(input_df)[0]
19
- return prediction
20
 
21
  # Gradio interface
22
  iface = gr.Interface(
23
  fn=predict,
24
  inputs=[
25
- gr.Slider(4.0, 8.0, label="Sepal Length (cm)"),
26
- gr.Slider(2.0, 5.0, label="Sepal Width (cm)"),
27
- gr.Slider(1.0, 7.0, label="Petal Length (cm)"),
28
- gr.Slider(0.1, 3.0, label="Petal Width (cm)")
29
  ],
30
- outputs="text",
31
- title="Iris Flower Classifier 🌸"
 
32
  )
33
 
34
  iface.launch()
 
1
  import gradio as gr
 
2
  import joblib
3
+ import pandas as pd
4
  from huggingface_hub import hf_hub_download
5
 
6
+ # Download model from Hugging Face Hub
7
  model_path = hf_hub_download(repo_id="abhishek/autotrain-iris-xgboost", filename="model.joblib")
8
  model = joblib.load(model_path)
9
 
10
+ # Input labels expected by the model
11
+ feature_names = ['feat_SepalLengthCm', 'feat_SepalWidthCm', 'feat_PetalLengthCm', 'feat_PetalWidthCm']
12
+
13
  def predict(sepal_length, sepal_width, petal_length, petal_width):
14
+ data = pd.DataFrame([[sepal_length, sepal_width, petal_length, petal_width]], columns=feature_names)
15
+ prediction = model.predict(data)[0]
16
+ return f"Predicted Iris Class: {prediction}"
 
 
 
 
 
17
 
18
  # Gradio interface
19
  iface = gr.Interface(
20
  fn=predict,
21
  inputs=[
22
+ gr.Number(label="Sepal Length (cm)"),
23
+ gr.Number(label="Sepal Width (cm)"),
24
+ gr.Number(label="Petal Length (cm)"),
25
+ gr.Number(label="Petal Width (cm)"),
26
  ],
27
+ outputs=gr.Textbox(label="Prediction"),
28
+ title="Iris Species Predictor 🌸",
29
+ description="Enter flower features to predict the Iris species using a model trained with AutoTrain Tabular."
30
  )
31
 
32
  iface.launch()