Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torchvision import transforms | |
from torch.utils.data import DataLoader | |
from datasets import load_dataset | |
from huggingface_hub import HfApi, Repository | |
import os | |
import matplotlib.pyplot as plt | |
import utils | |
# Hugging Face Hub settings. The token is read from the environment and is
# None when the variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
MODEL_REPO_ID = "louiecerv/amer_sign_lang_data_augmentation"
DATASET_REPO_ID = "louiecerv/american_sign_language"

# Pick the GPU when PyTorch reports one is available, otherwise the CPU,
# and show the choice on the page.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
st.write(f"Device: {device}")
# Define the new CNN model: input geometry and number of output classes.
IMG_HEIGHT, IMG_WIDTH = 28, 28
IMG_CHS = 1  # single input channel (grayscale)
# NOTE(review): presumably the 24 static ASL letters (J and Z involve motion);
# confirm against the dataset's label set.
N_CLASSES = 24
class MyConvBlock(nn.Module):
    """Conv -> BatchNorm -> ReLU -> Dropout -> 2x2 max-pool feature block.

    The 3x3 convolution uses padding 1, so it preserves height and width.
    The stride-2 max-pool then halves them, rounding down. An input of shape
    ``(N, in_ch, H, W)`` therefore comes out as
    ``(N, out_ch, H // 2, W // 2)``.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        dropout_p: dropout probability applied after the activation.
    """

    def __init__(self, in_ch, out_ch, dropout_p):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        # Keep both the attribute name ``model`` and the layer order.
        # Saved state_dict keys (``model.0.weight``, ...) depend on them.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to a batch of feature maps."""
        return self.model(x)
# Size of the flattened feature map entering the classifier head:
# 75 channels of 3 x 3.
flattened_img_size = 75 * 3 * 3

# Shapes below assume a 1 x 28 x 28 input. Each MyConvBlock halves height
# and width, rounding down: 28 -> 14 -> 7 -> 3.
_feature_layers = [
    MyConvBlock(IMG_CHS, 25, 0),    # -> 25 x 14 x 14
    MyConvBlock(25, 50, 0.2),       # -> 50 x 7 x 7
    MyConvBlock(50, 75, 0),         # -> 75 x 3 x 3
]
_head_layers = [
    nn.Flatten(),
    nn.Linear(flattened_img_size, 512),
    nn.Dropout(0.3),
    nn.ReLU(),
    nn.Linear(512, N_CLASSES),
]
# A single flat Sequential keeps the same module indices, and therefore the
# same state_dict keys, as listing every layer inline.
base_model = nn.Sequential(*_feature_layers, *_head_layers)
# Streamlit app
def _make_collate_fn(transform):
    """Return a ``DataLoader`` collate function that applies ``transform`` per sample.

    Each dataset item is expected to be a mapping with ``"pixel_values"`` and
    ``"label"`` keys; items missing either key are skipped. Pixel values are
    converted with ``torch.as_tensor`` before ``transform`` is applied.

    NOTE(review): assumes ``pixel_values`` holds float values shaped
    (1, H, W). ``transforms.Normalize`` rejects integer tensors, and the
    model's first conv expects one channel. Confirm against the dataset schema.

    A sample whose transform raises is reported with ``print`` and dropped
    (deliberate best-effort), so a batch may shrink or come back empty.

    Returns:
        A callable mapping a list of items to ``(images, labels)``: a stacked
        image tensor and an int64 label tensor, or two empty tensors when no
        sample survived. Tensors stay on the CPU; ``_run_epoch`` moves them
        to the model's device, which stays safe if loader workers are added.
    """

    def collate_fn(batch):
        images = []
        labels = []
        for item in batch:
            if "pixel_values" not in item or "label" not in item:
                continue
            image = torch.as_tensor(item["pixel_values"])
            try:
                image = transform(image)
            except Exception as e:  # best-effort: skip samples the transform can't handle
                print(f"Error processing image: {e}")
                continue
            images.append(image)
            labels.append(item["label"])
        if not images:
            return torch.tensor([]), torch.tensor([])
        return torch.stack(images), torch.tensor(labels).long()

    return collate_fn


def _run_epoch(model, loader, criterion, optimizer=None):
    """Run ``model`` once over every batch of ``loader``.

    With an ``optimizer`` this is a training epoch: train mode, gradients
    enabled, one optimizer step per batch. Without one it is an evaluation
    pass: eval mode, no gradients. Empty batches are skipped; these are what
    ``_make_collate_fn`` yields when every sample in a batch was dropped.

    Args:
        model: the network. Batches are moved to the device of its first
            parameter.
        loader: iterable of ``(images, labels)`` pairs.
        criterion: loss function, called as ``criterion(outputs, labels)``.
        optimizer: optional optimizer. Passing one selects training.

    Returns:
        ``(mean_loss, accuracy_percent)`` over the non-empty batches, or
        ``(nan, nan)`` when there were none, instead of dividing by zero.
    """
    training = optimizer is not None
    model.train(training)
    device = next(model.parameters()).device

    loss_sum = 0.0
    n_batches = 0
    correct = 0
    total = 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            if images.nelement() == 0:
                continue
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            n_batches += 1
            # Accuracy uses the outputs computed before the optimizer step.
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        return float("nan"), float("nan")
    return loss_sum / n_batches, 100 * correct / total


def _plot_history(loss_history, accuracy_history, val_accuracy_history=()):
    """Build a figure with loss (left axis) and accuracy % (right axis) per epoch.

    Epochs are numbered from 1 to match the per-epoch log lines. The caller
    owns the returned figure and should close it after rendering.
    """
    epochs = range(1, len(loss_history) + 1)
    fig, ax_loss = plt.subplots()
    ax_acc = ax_loss.twinx()
    ax_loss.plot(epochs, loss_history, "g-", label="Loss")
    ax_acc.plot(epochs, accuracy_history, "b-", label="Accuracy")
    if val_accuracy_history:
        ax_acc.plot(epochs, val_accuracy_history, "b--", label="Validation accuracy")
    ax_loss.set_xlabel("Epoch")
    ax_loss.set_ylabel("Loss", color="g")
    ax_acc.set_ylabel("Accuracy (%)", color="b")
    ax_loss.set_title("Training Loss and Accuracy")
    return fig


def main():
    """Render the ASL recognition app and, on request, train ``base_model``.

    The sidebar holds an epoch-count slider and a "Train Model" button.
    Nothing is downloaded or trained until the button is pressed. Streamlit
    re-runs the whole script on every widget interaction, so unconditional
    work here would repeat on each slider change.

    When training runs, each epoch:
      * trains on the augmented ``train`` split;
      * evaluates on the un-augmented ``validation`` split;
      * reports both results on the page.
    After the last epoch, the loss and accuracy curves are plotted.
    """
    st.title("American Sign Language Recognition")

    # Controls live in the sidebar.
    num_epochs = st.sidebar.slider("Number of Epochs", 1, 20, 5)
    train_button = st.sidebar.button("Train Model")
    if not train_button:
        return

    # Load the dataset from Hugging Face Hub.
    dataset = load_dataset(DATASET_REPO_ID)

    normalize = transforms.Normalize(mean=[0.5], std=[0.5])
    # Augmentation applies to training only. Validation gets just the
    # normalization, so its accuracy reflects unmodified images.
    # NOTE(review): assumes validation images are already
    # IMG_HEIGHT x IMG_WIDTH, since no crop/resize is applied to them.
    train_transforms = transforms.Compose([
        transforms.RandomRotation(5),
        # torchvision expects the crop size as (height, width).
        transforms.RandomResizedCrop((IMG_HEIGHT, IMG_WIDTH), scale=(.9, 1), ratio=(1, 1)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=.2, contrast=.5),
        normalize,
    ])
    train_loader = DataLoader(
        dataset["train"], batch_size=64, shuffle=True,
        collate_fn=_make_collate_fn(train_transforms),
    )
    val_loader = DataLoader(
        dataset["validation"], batch_size=64,
        collate_fn=_make_collate_fn(normalize),
    )

    # Model, loss, and optimizer.
    model = base_model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    loss_history = []
    accuracy_history = []
    val_accuracy_history = []
    for epoch in range(num_epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
        _, val_acc = _run_epoch(model, val_loader, criterion)
        loss_history.append(train_loss)
        accuracy_history.append(train_acc)
        val_accuracy_history.append(val_acc)
        st.write(
            f'Epoch [{epoch + 1}/{num_epochs}], Loss: {train_loss:.4f}, '
            f'Accuracy: {train_acc:.2f}%, Validation Accuracy: {val_acc:.2f}%'
        )

    fig = _plot_history(loss_history, accuracy_history, val_accuracy_history)
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed. Close it so that
    # repeated runs in the long-lived Streamlit process don't accumulate figures.
    plt.close(fig)
if __name__ == "__main__":
    # Launch the app only when this file is executed as a script, not when
    # it is imported.
    main()