bosayama commited on
Commit
45990d1
·
verified ·
1 Parent(s): a1dc0fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -28
app.py CHANGED
@@ -1,38 +1,34 @@
1
  import gradio as gr
2
- import cv2
3
- import numpy as np
4
  import tensorflow as tf
 
 
5
 
6
- # Teachable MachineモデルのURL
7
- URL = "https://teachablemachine.withgoogle.com/models/ZPfAhDYCh/"
8
-
9
- # モデルとメタデータのURL
10
- model_url = URL + "model.json"
11
- metadata_url = URL + "metadata.json"
12
-
13
- # モデルとメタデータのロード
14
- model = tf.keras.models.load_model(model_url)
15
- metadata = gr.tfjs.ModelMetadata(metadata_url)
16
-
17
- # ウェブカメラのキャプチャ
18
- cap = cv2.VideoCapture(0)
19
 
20
- # カメラからの入力を処理する関数
21
- def classify_image(frame):
22
- # フレームをリサイズしてモデルに適した形式に変換
23
- input_data = cv2.resize(frame, (200, 200))
24
- input_data = np.expand_dims(input_data, axis=0)
 
25
 
26
- # 画像の予測
27
- predictions = model.predict(input_data)
 
 
 
28
 
29
- # 予測結果を表示
30
- results = {metadata.get_class_label(i): float(predictions[0][i]) for i in range(len(predictions[0]))}
31
- return results
 
32
 
33
- # インターフェースの作成
34
- iface = gr.Interface(fn=classify_image, inputs="webcam", outputs="label")
35
 
36
- # インターフェースの起動
37
  iface.launch()
38
 
 
 
1
  import gradio as gr
 
 
2
  import tensorflow as tf
3
+ from tensorflow.keras.preprocessing import image
4
+ import numpy as np
5
 
6
+ # Load the Teachable Machine Image Model
7
+ model_url = "https://teachablemachine.withgoogle.com/models/ZPfAhDYCh/"
8
+ model = tf.keras.models.load_model(model_url + "model.json")
 
 
 
 
 
 
 
 
 
 
9
 
10
+ # Function to preprocess the image before making predictions
11
+ def preprocess_image(img):
12
+ img = image.img_to_array(img)
13
+ img = np.expand_dims(img, axis=0)
14
+ img /= 255.0 # Normalize pixel values to be between 0 and 1
15
+ return img
16
 
17
+ # Function to make predictions using the loaded model
18
+ def predict_image(img):
19
+ img = preprocess_image(img)
20
+ prediction = model.predict(img)
21
+ return {class_name: float(prediction[0][i]) for i, class_name in enumerate(classes)}
22
 
23
+ # Fetch class labels from metadata.json
24
+ metadata_url = model_url + "metadata.json"
25
+ metadata = tf.keras.utils.get_file("metadata.json", metadata_url)
26
+ classes = tf.keras.utils.get_file("classes.txt", metadata_url.replace("metadata.json", "classes.txt")).read().splitlines()
27
 
28
+ # Create Gradio interface
29
+ iface = gr.Interface(fn=predict_image, inputs="image", outputs="label", live=True, capture_session=True)
30
 
31
+ # Launch Gradio interface
32
  iface.launch()
33
 
34
+