DEV Community

Sam Der

Getting Started With Machine Learning Using MNIST

What is Machine Learning?

Machine learning has been one of the hottest topics in computer science in recent years. With the rise of tools like ChatGPT by OpenAI, Copilot by Microsoft, and countless other products built on large language models, which are themselves powered by machine learning, it's becoming ever more relevant in our daily lives. From helping us write emails to answering our questions about life and the pursuit of happiness, these chatbots seem almost omnipotent.

But how do we get from machine learning to large language models? How do we get from training a model that can, for example, identify handwritten digits to something as complex as ChatGPT? I'll be publishing posts every week that go further in depth into answering these questions, but for now, let's start with the basics.

The Basics

Machine learning is exactly what it sounds like: making a machine (or more specifically, a computer) learn. More precisely, it means writing code that trains a computer on data so that it can make predictions later down the line. The result of this training is called a model.

In the handwritten digits example, you might think we could train a model on any existing images of numbers, but we can't. Since each of us has our own way of writing digits, we can't train on screenshots of typed numbers and then hope the model extrapolates to our handwriting. A given font renders each digit exactly the same way on a computer screen, whereas every time we write a number by hand there are tiny differences, and those differences could convince a computer that we wrote down a different number. We need to train our model on the same type of data we want it to make predictions on. In other words, we want to train our model on images of handwritten digits so that it can classify handwritten digits it hasn't seen before.

Google Colab

To get started, we can head over to Google Colab, a notebook-based coding environment. This means you can split your code into individual cells, re-run each one on its own, and view the output of each one.


A simple example using the Fibonacci sequence. Competitive programmers know what I'm doing here :). If you don't, then hopefully you learned something new!

Notice that you don't need the print function to view output here. In a notebook, the value of the last expression in a cell is displayed automatically, but if you want to see output from any other line, you'll have to use print.
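Since the screenshot isn't reproduced here, below is a minimal sketch of that kind of cell, assuming an iterative Fibonacci function; the original example may well have looked different, and the name `fib` is just for illustration. Run in a notebook, the bare `fib(10)` on the last line is displayed without print, while the earlier line needs print to show anything.

```python
def fib(n: int) -> int:
    """Return the n-th Fibonacci number, with fib(0) = 0 and fib(1) = 1."""
    a, b = 0, 1
    for _ in range(n):
        # Slide the pair forward one step: (F(i), F(i+1)) -> (F(i+1), F(i+2))
        a, b = b, a + b
    return a

print(fib(5))  # not the last line, so print is needed to see 5
fib(10)        # last expression in the cell: a notebook shows 55 automatically
```

The iterative loop runs in linear time, unlike the naive recursive version, which recomputes the same values over and over.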

Training A Model With scikit-learn To Recognize Digits

Getting Our Data

A popular library for getting started with machine learning is scikit-learn, which comes already installed in Colab. We can use it to train our model on the MNIST dataset of handwritten digits.

# Importing the necessary libraries and functions that will be used later
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn.datasets import fetch_openml
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# Download MNIST from OpenML (70,000 images of 28x28 pixels) as pandas objects
mnist = fetch_openml('mnist_784', version=1, as_frame=True)
mnist.data


The dataset: each row represents an image, and each column holds the grayscale intensity (from 0 to 255) of one pixel in that image. Each image is 28 × 28 = 784 pixels, so there are 784 columns.
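To see how a 784-column row relates to a 28×28 picture, here is a small NumPy sketch using a made-up image (not real MNIST data). It assumes the pixels are laid out row by row, which is consistent with the reshape(28, 28) used to plot an image later in this post.

```python
import numpy as np

# A synthetic 28x28 grayscale image: 0 is black, 255 is white.
image = np.zeros((28, 28), dtype=np.uint8)
image[5:23, 13:15] = 255  # a vertical stroke, roughly a "1"

# Flatten row by row into a single vector, like one row of mnist.data.
row = image.reshape(-1)
print(row.shape)  # (784,)

# Pixel (r, c) of the image lands in column r * 28 + c of the row.
r, c = 10, 13
assert row[r * 28 + c] == image[r, c]

# Reshaping back to 28x28 recovers the original image exactly.
restored = row.reshape(28, 28)
```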

mnist.target


The targets: this contains the digit each image corresponds to. For example, the image in the first row is a 5. Note that the labels are stored as strings (such as '5') rather than integers.

# Convert the pixel DataFrame to a NumPy array, then draw the first image as a 28x28 grayscale picture
images = mnist.data.to_numpy()
plt.imshow(images[0].reshape(28, 28), cmap='gray')


An example image rendered using numpy and matplotlib.
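To look at several digits at once, a small helper like the hypothetical `show_digits` below (not part of any library) can lay out images in a grid, each titled with its label. It's a sketch that assumes every row of `images` is a flattened 28×28 image, as above.

```python
import matplotlib.pyplot as plt
import numpy as np

def show_digits(images, labels, n=10, cols=5):
    """Plot the first n flattened 28x28 images in a grid, titled with their labels."""
    rows = (n + cols - 1) // cols  # ceiling division so all n images fit
    fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows))
    for i, ax in enumerate(np.ravel(axes)):
        ax.axis('off')  # hide ticks; also leaves any unused cells blank
        if i < n:
            ax.imshow(images[i].reshape(28, 28), cmap='gray')
            ax.set_title(str(labels[i]))
    fig.tight_layout()
    return fig
```

With the MNIST data loaded, `show_digits(images, mnist.target.to_numpy())` would draw the first ten digits with their labels.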

Now that we have our images, let's proceed with building our model! Since we're assigning each image to one of ten digits, this is a classification problem, so we'll use a classification algorithm. To keep things simple, we'll use one of the most common ones: k-nearest neighbors (k-NN). To classify a new data point, k-NN finds the k training points closest to it in terms of Euclidean distance (or some other distance metric) and outputs the label that is most common among them. This suits our scenario because images with similar pixel intensities are likely to show the same digit.
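To make the idea concrete, here is a brute-force sketch of k-NN in plain NumPy. `knn_predict` is a hypothetical helper for illustration only: it compares each query against every training point, so it's much slower than scikit-learn's KNeighborsClassifier, and ties are broken in favor of the label that appears first among the sorted neighbors, which may not match scikit-learn's tie-breaking.

```python
import numpy as np
from collections import Counter

def knn_predict(X_train, y_train, X_query, k=5):
    """Label each row of X_query by majority vote among its k nearest
    training rows, measured with Euclidean distance (brute force)."""
    # Cast to float so subtracting unsigned-integer pixels can't wrap around.
    X_train = np.asarray(X_train, dtype=float)
    X_query = np.asarray(X_query, dtype=float)
    y_train = np.asarray(y_train)
    predictions = []
    for x in X_query:
        # Euclidean distance from x to every training row.
        distances = np.sqrt(((X_train - x) ** 2).sum(axis=1))
        # Indices of the k closest training rows, nearest first.
        nearest = np.argsort(distances)[:k]
        # Most common label among those neighbors (ties go to the one seen first).
        label, _ = Counter(y_train[nearest]).most_common(1)[0]
        predictions.append(label)
    return np.array(predictions)
```

Once the data has been split into training and test sets, something like `knn_predict(X_train, y_train, X_test[:100])` should broadly agree with KNeighborsClassifier(n_neighbors=5), whose default metric is also Euclidean distance. Sticking to a small slice of the test set keeps the run time reasonable.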

Running Our Classifier

To train our model and test it fairly, we'll need to split this large dataset into smaller training and testing datasets. Luckily, scikit-learn provides a function named train_test_split to do just that! It accepts the images and their labels as separate arrays and shuffles and splits them together, so each image stays paired with its label.

# Hold out 20% of the images (and their matching labels) for testing
X_train, X_test, y_train, y_test = train_test_split(images, mnist.target.to_numpy(), test_size=0.2)

Setting test_size=0.2 holds out 20% of the data for testing. The ratio can be adjusted, but an 80/20 split is a common default.
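Under the hood, a random split boils down to shuffling the row indices once and slicing. The hypothetical `shuffle_split` below is a simplified sketch of that idea, not scikit-learn's implementation: it applies one shared permutation to both X and y so labels stay attached to their images, rounds the test count up, and leaves out features such as stratification.

```python
import numpy as np

def shuffle_split(X, y, test_size=0.2, seed=0):
    """Shuffle X and y with the same permutation, then hold out a
    test_size fraction of the rows as the test set."""
    X, y = np.asarray(X), np.asarray(y)
    if len(X) != len(y):
        raise ValueError("X and y must have the same number of rows")
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(X))  # one shared shuffle keeps rows and labels paired
    n_test = int(np.ceil(len(X) * test_size))
    test_idx, train_idx = order[:n_test], order[n_test:]
    # Same return order as train_test_split: X_train, X_test, y_train, y_test
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]
```

For the 70,000 MNIST images, test_size=0.2 would give 56,000 training rows and 14,000 test rows.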

We can then use the KNeighborsClassifier that we imported at the beginning to train our model on the training dataset and then measure its accuracy on the test dataset. This snippet looks at the 5 nearest neighbors, but feel free to change the number as you see fit!

# Fit a k-NN classifier that votes among the 5 nearest training images
model = KNeighborsClassifier(n_neighbors=5)
model.fit(X_train, y_train)
# Fraction of test images whose predicted digit matches the true label
accuracy_score(y_test, model.predict(X_test))

An output of about 0.97 means that our model classified roughly 97% of the test images correctly. That's a pretty good score! And congratulations, you just built your first machine learning model!
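Rather than guessing the number of neighbors, one common approach is to compare a few values with cross-validation on the training set only, so the test set stays untouched until the end. The hypothetical `pick_k` helper below sketches this with scikit-learn's cross_val_score; the candidate values and cv=5 are arbitrary choices, not recommendations from this post.

```python
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

def pick_k(X, y, candidates=(1, 3, 5, 7), cv=5):
    """Return (best_k, scores), where scores maps each k to its mean
    cross-validated accuracy on (X, y)."""
    scores = {}
    for k in candidates:
        model = KNeighborsClassifier(n_neighbors=k)
        # Train on cv-1 folds, score accuracy on the held-out fold, repeat cv times
        scores[k] = cross_val_score(model, X, y, cv=cv, scoring='accuracy').mean()
    best_k = max(scores, key=scores.get)
    return best_k, scores
```

Because every candidate is fitted cv times, running this on all 56,000 training images can take a while; a slice such as `pick_k(X_train[:5000], y_train[:5000])` is a quicker first try.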

Where To Go From Here?

We used the scikit-learn library in this tutorial, but for more demanding machine learning applications, for instance ones that need more control over the model's parameters or heavier use of compute resources such as GPUs, it would be better to use a library like PyTorch or TensorFlow. I'll be going over both in later posts!

In the meantime, thanks for sticking with me. See you in the next one!
