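# Residue from the Hugging Face file-view page, kept as comments so the
# module parses:
# louiecerv's picture
# save changes
# 1cc1116
# raw
# history blame
# 4.82 kB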
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from datasets import load_dataset
from huggingface_hub import HfApi, Repository
import os
import matplotlib.pyplot as plt
import utils
# Credentials and repository identifiers for the Hugging Face Hub.
HF_TOKEN = os.environ.get("HF_TOKEN")  # None when the variable is unset
MODEL_REPO_ID = "louiecerv/amer_sign_lang_data_augmentation"
DATASET_REPO_ID = "louiecerv/american_sign_language"

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
st.write(f"Device: {device}")

# Input geometry and number of output classes used by the CNN below.
IMG_HEIGHT, IMG_WIDTH = 28, 28
IMG_CHS = 1
N_CLASSES = 24
class MyConvBlock(nn.Module):
    """Convolution block: Conv2d -> BatchNorm2d -> ReLU -> Dropout -> MaxPool2d.

    The convolution uses stride 1 and ``kernel_size // 2`` padding, so for odd
    kernel sizes it keeps the spatial size of its input. The 2x2 max-pool with
    stride 2 then halves height and width (rounding down).

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels produced by the convolution.
        dropout_p: Dropout probability applied after the ReLU.
        kernel_size: Side length of the square convolution kernel. Defaults
            to 3, which reproduces the previously hard-coded 3x3 kernel with
            padding 1.
    """

    def __init__(self, in_ch: int, out_ch: int, dropout_p: float,
                 kernel_size: int = 3) -> None:
        super().__init__()
        # The layers stay in a single positional nn.Sequential called
        # ``model`` so parameter names (``model.0.weight``, ...) do not change.
        self.model = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size, stride=1,
                      padding=kernel_size // 2),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.MaxPool2d(2, stride=2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the block and return the pooled feature map."""
        return self.model(x)
# Size of the flattened feature map fed to the first linear layer:
# 75 channels x 3 x 3 after the third conv block.
flattened_img_size = 75 * 3 * 3

# Classifier for 1 x 28 x 28 inputs; the trailing comments give the output
# shape of each conv block. Layers are kept positional so the parameter
# names of the Sequential stay the same.
base_model = nn.Sequential(
    MyConvBlock(in_ch=IMG_CHS, out_ch=25, dropout_p=0),  # 25 x 14 x 14
    MyConvBlock(in_ch=25, out_ch=50, dropout_p=0.2),  # 50 x 7 x 7
    MyConvBlock(in_ch=50, out_ch=75, dropout_p=0),  # 75 x 3 x 3
    nn.Flatten(),
    nn.Linear(in_features=flattened_img_size, out_features=512),
    nn.Dropout(p=0.3),
    nn.ReLU(),
    nn.Linear(in_features=512, out_features=N_CLASSES),
)
# Streamlit app
def main():
    """Render the training UI and, on request, train ``base_model``.

    The sidebar exposes an epoch-count slider and a "Train Model" button.
    When the button is pressed, the model is trained on the ``train`` split
    of ``DATASET_REPO_ID`` with on-the-fly augmentation. Per-epoch loss and
    accuracy are written to the page, and both curves are plotted once
    training has produced at least one epoch of history.
    """
    st.title("American Sign Language Recognition")

    # Move slider and button to sidebar
    num_epochs = st.sidebar.slider("Number of Epochs", 1, 20, 5)
    train_button = st.sidebar.button("Train Model")

    # Load the dataset from Hugging Face Hub
    dataset = load_dataset(DATASET_REPO_ID)

    # Data loaders with preprocessing and data augmentation:
    random_transforms = transforms.Compose([
        transforms.RandomRotation(5),
        transforms.RandomResizedCrop((IMG_WIDTH, IMG_HEIGHT), scale=(.9, 1), ratio=(1, 1)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=.2, contrast=.5),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    def collate_fn(batch):
        """Build an augmented ``(images, labels)`` batch on ``device``.

        Items lacking ``'pixel_values'`` or ``'label'`` are ignored, and items
        whose augmentation raises are reported on stdout and dropped. When no
        item survives, a pair of empty tensors is returned so the caller can
        detect and skip the batch.
        """
        images = []
        labels = []
        for item in batch:
            if 'pixel_values' not in item or 'label' not in item:
                continue
            # assumes pixel_values converts to a tensor laid out the way the
            # torchvision tensor transforms expect - TODO confirm
            image = torch.tensor(item['pixel_values'])
            label = item['label']
            try:
                image = random_transforms(image)
            except Exception as e:
                # Deliberate best-effort: one bad sample must not abort the batch.
                print(f"Error processing image: {e}")
                continue
            images.append(image)
            labels.append(label)

        if not images:
            return torch.tensor([]), torch.tensor([])

        images = torch.stack(images).to(device)
        labels = torch.tensor(labels).long().to(device)
        return images, labels

    train_loader = DataLoader(dataset["train"], batch_size=64, shuffle=True, collate_fn=collate_fn)
    # NOTE(review): val_loader shares the augmenting collate_fn and is not
    # used anywhere in this function; give it a deterministic transform
    # before using it for evaluation.
    val_loader = DataLoader(dataset["validation"], batch_size=64, collate_fn=collate_fn)

    # Model, loss, and optimizer
    model = base_model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    loss_history = []
    accuracy_history = []

    if not train_button:
        return

    model.train()
    for epoch in range(num_epochs):
        total = 0
        correct = 0
        epoch_loss = 0.0
        batches_seen = 0  # only batches that actually contributed a loss

        for images, labels in train_loader:
            if images.nelement() == 0:
                continue

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            batches_seen += 1
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        if total == 0:
            # Every batch was empty: avoid dividing by zero below.
            st.warning(
                f'Epoch [{epoch + 1}/{num_epochs}]: no valid training samples, epoch skipped.'
            )
            continue

        # Average over processed batches so skipped ones do not dilute the loss.
        mean_loss = epoch_loss / batches_seen
        epoch_accuracy = 100 * correct / total
        loss_history.append(mean_loss)
        accuracy_history.append(epoch_accuracy)
        st.write(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {mean_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%')

    if not loss_history:
        return

    # Plot loss and accuracy
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax1.plot(loss_history, 'g-', label='Loss')
    ax2.plot(accuracy_history, 'b-', label='Accuracy')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss', color='g')
    ax2.set_ylabel('Accuracy (%)', color='b')
    ax1.set_title('Training Loss and Accuracy')
    st.pyplot(fig)
    # Release the figure so reruns of the script do not accumulate open figures.
    plt.close(fig)
# Launch the app when this file is executed as the main script.
if __name__ == "__main__":
    main()