titipata commited on
Commit
f404fb6
·
1 Parent(s): 9a7b163

Initial commit

Browse files
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import joblib
2
+ import numpy as np
3
+ import pandas as pd
4
+ import gradio as gr
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import Dataset, DataLoader
9
+
10
+
11
+ brands = [
12
+ 'Toyota', 'Honda', 'Mazda', 'Mitsubishi',
13
+ 'Nissan', 'Suzuki'
14
+ ]
15
+ models = [
16
+ 'Vios', 'Altis', 'Civic', 'Mazda3', 'Camry',
17
+ 'Mirage', 'Brio', 'Lancer Ex', 'Jazz', 'Accord',
18
+ 'Lancer', 'Yaris', 'Almera', 'City', 'Swift', 'Mazda2',
19
+ 'Teana', 'Note', 'Celerio', 'March', 'Tiida', 'Prius',
20
+ 'Ciaz', 'Sylphy', 'Pulsar', 'Attrage', 'Sunny'
21
+ ]
22
+ engines = [
23
+ 1.5, 1.8, 1.7, 2.0, 1.2, 1.6, 2.4,
24
+ 2.5, 1.0, 1.3, 2.3, 3.0, 2.2
25
+ ]
26
+ segments = ['B-Segment', 'C-Segment', 'D-Segment', 'Eco Car']
27
+ provinces = [
28
+ 'สงขลา', 'กรุงเทพมหานคร', 'สระบุรี', 'ชัยนาท', 'ระยอง', 'นครสวรรค์',
29
+ 'นนทบุรี', 'ตาก', 'สมุทรสาคร', 'เชียงใหม่', 'ลำปาง', 'สุพรรณบุรี', 'เชียงราย',
30
+ 'เพชรบุรี', 'พิษณุโลก', 'นครปฐม', 'อุดรธานี', 'สมุทรปราการ', 'ปทุมธานี',
31
+ 'นครราชสีมา', 'ชลบุรี', 'ปัตตานี', 'ราชบุรี', 'ลำพูน', 'กระบี่', 'ฉะเชิงเทรา',
32
+ 'พัทลุง', 'อ่างทอง', 'ขอนแก่น', 'ปราจีนบุรี', 'สุราษฎร์ธานี', 'ภูเก็ต',
33
+ 'หนองบัวลำภู', 'พิจิตร', 'พะเยา', 'ตราด', 'นครศรีธรรมราช', 'บุรีรัมย์',
34
+ 'ลพบุรี', 'อุตรดิตถ์', 'ยโสธร', 'อุบลราชธานี', 'สิงห์บุรี', 'พระนครศรีอยุธยา',
35
+ 'กาฬสินธุ์', 'สกลนคร', 'ร้อยเอ็ด', 'ระนอง', 'นครพนม', 'อุทัยธานี', 'จันทบุรี',
36
+ 'มหาสารคาม', 'กาญจนบุรี', 'แพร่', 'บึงกาฬ', 'กำแพงเพชร', 'สมุทรสงคราม',
37
+ 'สุโขทัย', 'ตรัง', 'แม่ฮ่องสอน', 'อำนาจเจริญ', 'นครนายก', 'ชัยภูมิ', 'พังงา',
38
+ 'สระแก้ว', 'สุรินทร์', 'นราธิวาส', 'สตูล', 'ประจวบคีรีขันธ์', 'เพชรบูรณ์', 'ศรีสะเกษ',
39
+ 'หนองคาย', 'ยะลา', 'น่าน'
40
+ ]
41
+ colors = ['Gray', 'Black', 'Gold', 'Silver', 'Brown', 'White',
42
+ 'Red', 'Yellow', 'Blue', 'Green', 'Cyan', 'Orange']
43
+ examples = [
44
+ ['Honda', 'Civic', 1.8, 'C-Segment', 'ตรัง', 'Gray', 2009, 185477.0],
45
+ ['Honda', 'Accord', 2.4, 'D-Segment', 'ขอนแก่น', 'Black', 2003, 166508.0],
46
+ ['Honda', 'Jazz', 1.5, 'B-Segment', 'กรุงเทพมหานคร', 'White', 2011, 62000.0],
47
+ ['Honda', 'Civic', 1.8, 'C-Segment', 'พระนครศรีอยุธยา', 'White', 2012, 165346.0],
48
+ ['Suzuki', 'Swift', 1.2, 'Eco Car', 'กรุงเทพมหานคร', 'White', 2016, 193000.0],
49
+ ['Honda', 'City', 1.0, 'B-Segment', 'กรุงเทพมหานคร', 'Gray', 2020, 29000.0],
50
+ ['Honda', 'City', 1.5, 'B-Segment', 'พิษณุโลก', 'Gray', 2007, 126208.0],
51
+ ['Toyota', 'Yaris', 1.5, 'Eco Car', 'เชียงใหม่', 'White', 2013, 100000.0],
52
+ ['Toyota', 'Altis', 1.6, 'C-Segment', 'กรุงเทพมหานคร', 'Silver', 2009, 260000.0],
53
+ ['Honda', 'Civic', 1.8, 'C-Segment', 'กรุงเทพมหานคร', 'Silver', 2006, 232433.0],
54
+ ]
55
+ CAT_COLUMNS = ["Brand", "Model", "Engine", "Segment", "Province", "Color"]
56
+
57
+
58
+ class CarPriceDataset(Dataset):
59
+ def __init__(self, X, y = None):
60
+ self.X = X
61
+ if y is not None:
62
+ self.y = y
63
+ else:
64
+ self.y = None
65
+
66
+ def __len__(self):
67
+ return len(self.X)
68
+
69
+ def __getitem__(self, idx):
70
+ if self.y is not None:
71
+ return self.X[idx], self.y[idx]
72
+ else:
73
+ return self.X[idx]
74
+
75
+ class CarPriceTwoLayerModel(nn.Module):
76
+ def __init__(self, input_size, output_size, intermediate_dim = 10):
77
+ super().__init__()
78
+ self.linear1 = nn.Linear(input_size, intermediate_dim)
79
+ self.linear2 = nn.Linear(intermediate_dim, output_size)
80
+ self.relu = nn.ReLU()
81
+
82
+ def forward(self, x):
83
+ x = self.linear1(x)
84
+ x = self.relu(x)
85
+ x = self.linear2(x)
86
+ return x
87
+
88
+ # Load model
89
+ pred_model = CarPriceTwoLayerModel(138, 1)
90
+ pred_model.load_state_dict(torch.load("carprice_two_layer_model_mse_00015.pth"))
91
+
92
+ # Load one-hot encoder and scaler
93
+ ohe = joblib.load("one_hot_encoder.joblib")
94
+ year_scaler = joblib.load("year_scaler.joblib")
95
+ mileage_scaler = joblib.load("mileage_scaler.joblib")
96
+ price_scaler = joblib.load("price_scaler.joblib")
97
+
98
+
99
+ def predict(model, data_loader):
100
+ model.eval()
101
+ y_pred_list = []
102
+ for x in data_loader:
103
+ y_pred = model(x.float())
104
+ prediction = y_pred.detach().numpy()
105
+ y_pred_list.extend(prediction)
106
+ y_pred_list = np.concatenate(y_pred_list)
107
+ return y_pred_list
108
+
109
+
110
+ def predict_car_price(
111
+ brand: str, model: str, engine: float, segment: str, province: str,
112
+ color: str, year: float, mileage: float
113
+ ):
114
+ df = pd.DataFrame([{
115
+ "Brand": brand,
116
+ "Model": model,
117
+ "Engine": engine,
118
+ "Segment": segment,
119
+ "Province": province,
120
+ "Color": color,
121
+ "Year": year,
122
+ "Mileage": mileage,
123
+ }])
124
+ features = np.hstack([
125
+ ohe.transform(df[CAT_COLUMNS]),
126
+ year_scaler.transform(df[["Year"]]),
127
+ mileage_scaler.transform(df[["Mileage"]])
128
+ ])
129
+ feat_dataset = CarPriceDataset(features)
130
+ dataloaders = DataLoader(feat_dataset, batch_size=32, shuffle=False)
131
+ y_pred_lr = predict(pred_model, dataloaders)
132
+ return int(price_scaler.inverse_transform(y_pred_lr.reshape(-1, 1)).ravel()[0])
133
+
134
+
135
+ interface = gr.Interface(
136
+ fn=predict_car_price,
137
+ inputs=[
138
+ gr.Dropdown(brands, label="Brand", info="Select Car Brand"),
139
+ gr.Dropdown(models, label="Model", info="Select Car Model"),
140
+ gr.Dropdown(engines, label="Engine Size", info="Select Engine Size"),
141
+ gr.Dropdown(segments, label="Car segment", info="Select Car Segment"),
142
+ gr.Dropdown(provinces, label="Province", info="Select Province"),
143
+ gr.Dropdown(colors, label="Color", info="Select Color"),
144
+ gr.Slider(1990, 2023, label="Year", info="Select Year"),
145
+ gr.Slider(0, 400000, label="Mileage", info="Select Mileage"),
146
+ ],
147
+ outputs=gr.Textbox(label="ราคาทำนาย (บาท)", placeholder="xxx,xxx (บาท)"),
148
+ examples=examples,
149
+ title="ทำนายราคารถมือสอง",
150
+ description="ตัวอย่างแอพพลิเคชั่นสำหรับคำนวณราคารถมือสอง",
151
+ )
152
+ interface.launch()
carprice_two_layer_model_mse_00015.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de4e5ec96b38ff26a2395f90a07398aaf75aee33d1fe861c0d02ee8dc4382422
3
+ size 7553
mileage_scaler.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:765d6c3426f034c6cc807f0fc5576ec44dca78080db2e8eabc3f9d6e87ccb6cb
3
+ size 909
one_hot_encoder.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7c68c8fc8f04f21e440f60b3dfb6adc5364dd986a78540c495cb90d00b8265b
3
+ size 5034
price_scaler.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0dc64226c6ebd02dbde58eb62d52ba13d61cee5390cd410867902c8ba52b4c82
3
+ size 907
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio
2
+ gradio_client==0.2.7
year_scaler.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86f4ce94683756213f272a70526c81ce2e5d087f5b9cc45dd9b373177d8667c3
3
+ size 906