|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
import numpy as np |
|
from safetensors.torch import save_file, load_file |
|
|
|
# Number of full-batch optimization steps performed by the training loop.
epochs = 10_000

# Toy dataset. Columns: temperature (°C), rain flag (0/1), target label
# (1 = "can go for a run", 0 = "cannot").
data = np.array(
    [
        [30, 0, 1],
        [22, 1, 0],
        [25, 0, 1],
        [15, 1, 0],
        [20, 0, 1],
    ]
)

# Standardize the two input features column-wise (zero mean, unit std).
# The population std (ddof=0) is used, matching np.std's default.
_features = data[:, :2]
X = (_features - _features.mean(axis=0)) / _features.std(axis=0)

# Labels as a float32 column vector, shape (n_samples, 1), as BCELoss expects.
y = torch.tensor(data[:, 2:3], dtype=torch.float32)
|
|
|
|
|
class SimpleNN(nn.Module):
    """Small fully connected binary classifier: 2 inputs -> 8 -> 4 -> 1.

    Hidden layers use ReLU activations. The single output passes through a
    sigmoid, so ``forward`` yields values in [0, 1] suitable for ``nn.BCELoss``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 8)
        self.fc2 = nn.Linear(8, 4)
        self.fc3 = nn.Linear(4, 1)
        # Parameter-free, so it adds nothing to the state_dict; kept as an
        # attribute so existing code that references ``model.sigmoid`` works.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a batch of 2-feature rows to one probability per row.

        Args:
            x: Float tensor whose last dimension is 2.

        Returns:
            Tensor whose last dimension is 1, holding sigmoid outputs.
        """
        # Functional ReLU avoids constructing a new nn.ReLU module on every call.
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.sigmoid(self.fc3(x))
|
|
|
|
|
model = SimpleNN()
# SimpleNN ends in a sigmoid, so its outputs are probabilities as BCELoss expects.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Convert the normalized NumPy features to a float32 tensor once, rather than
# rebuilding an identical tensor on every epoch.
X_tensor = torch.tensor(X, dtype=torch.float32)

# Full-batch training: each epoch uses the whole dataset. Nothing in the loop
# changes the module mode, so switching to train mode once is sufficient.
model.train()
for epoch in range(epochs):
    optimizer.zero_grad()
    output = model(X_tensor)
    loss = criterion(output, y)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 20 == 0:
        # Report progress against the configured epoch count.
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')

# Persist the trained weights; the file name embeds the epoch count.
save_file(model.state_dict(), f'WEATHER-RUN-{epochs}.safetensors')
|
|
|
|
|
def load_and_run_model(model_path, input_data):
    """Rebuild a SimpleNN from a safetensors checkpoint and run inference.

    Args:
        model_path: Path to a safetensors file holding a SimpleNN state_dict.
        input_data: Anything ``torch.tensor`` accepts. In this script it is the
            normalized 2-feature test matrix (presumably shape (n, 2); confirm
            for other callers).

    Returns:
        NumPy array of the network's sigmoid outputs, one row per input row.
    """
    net = SimpleNN()
    weights = load_file(model_path)
    net.load_state_dict(weights)
    net.eval()

    # Pure inference: disable autograd so outputs can be converted to NumPy.
    with torch.no_grad():
        batch = torch.tensor(input_data, dtype=torch.float32)
        probabilities = net(batch)
    return probabilities.numpy()
|
|
|
|
|
# Held-out samples as [temperature, rain flag] pairs.
test_data = [[25, 0], [18, 1], [21, 0], [19, 1]]

# Scale the test samples with the *training* feature statistics so they share
# the distribution the network was trained on.
_train_features = data[:, :2]
normalized_test_data = (
    (np.array(test_data) - _train_features.mean(axis=0)) / _train_features.std(axis=0)
)

predictions = load_and_run_model(f'WEATHER-RUN-{epochs}.safetensors', normalized_test_data)

for sample, prediction in zip(test_data, predictions):
    temp, rain = sample
    # A probability of at least 0.5 counts as a positive ("can run") decision.
    if prediction[0] >= 0.5:
        verdict = "can"
    else:
        verdict = "cannot"
    weather = 'rain' if rain else 'no rain'
    print(f"With temperature {temp}°C and {weather}, you {verdict} go for a run.")
|
|
|
|