Skip to content

Image Classification with CNN

In this tutorial, we will build a simple Convolutional Neural Network (CNN) to classify handwritten digits from the MNIST dataset.

1. Import Libraries

python
import tensorplay as tp
import tensorplay.nn as nn
import tensorplay.optim as optim
from tensorplay.utils.data import DataLoader
from tensorplay.datasets import MNIST

2. Define the Model

We'll use two convolutional layers followed by two fully connected layers.

python
class CNN(nn.Module):
    """Small convolutional classifier for 28x28 single-channel digit images.

    Two conv stages (conv -> ReLU -> 2x2 max-pool) reduce the spatial size
    28 -> 14 -> 7, then two fully connected layers map the flattened
    64*7*7 features to 10 class logits.
    """

    def __init__(self):
        super().__init__()
        # NOTE: layer-creation order is kept identical to preserve the
        # random-initialization sequence and state_dict key names.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Stage 1 and 2: convolution, non-linearity, spatial downsampling.
        features = self.pool(self.relu(self.conv1(x)))
        features = self.pool(self.relu(self.conv2(features)))
        # Collapse channel/height/width into one feature vector per sample.
        flat = features.flatten(1)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)

model = CNN()

3. Training Setup

python
# Hyperparameters
batch_size = 64
learning_rate = 0.001
epochs = 5

# Data loaders.
# The evaluation section below iterates `test_loader`, which was previously
# never defined (NameError at evaluation time) — build it here alongside
# the training loader. The test set is not shuffled: order does not matter
# for accuracy and a fixed order keeps evaluation reproducible.
train_dataset = MNIST(root='./data', train=True, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = MNIST(root='./data', train=False, download=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

4. Training Loop

python
# Ensure the model is in training mode: the evaluation section calls
# model.eval(), so set the mode explicitly in case this loop is re-run.
model.train()

for epoch in range(epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize: clear stale gradients from the previous
        # step, backpropagate, then apply the parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Periodic progress report every 100 mini-batches.
        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

5. Evaluation

python
# Switch to evaluation mode and disable gradient tracking for inference.
model.eval()
with tp.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        # Predicted class = index of the max logit along the class axis.
        # Use `outputs` directly: the `.data` attribute is a deprecated
        # autograd-bypassing access and is redundant inside no_grad().
        _, predicted = tp.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f'Test Accuracy: {100 * correct / total}%')

Released under the Apache 2.0 License.

📚DeepWiki