Custom Training a TensorFlow Model
This tutorial shows you how to train a machine learning model with a custom training loop to categorize penguins by species. In this notebook, you use TensorFlow to accomplish the following:
Import a dataset
Build a simple linear model
Train the model
Evaluate the model's effectiveness
Use the trained model to make predictions
Penguin classification problem
Imagine you are an ornithologist seeking an automated way to categorize each penguin you find. Machine learning provides many algorithms to classify penguins statistically. For instance, a sophisticated machine learning program could classify penguins based on photographs. The model you build in this tutorial is a little simpler. It classifies penguins based on their body weight, flipper length, and beaks, specifically the length and depth measurements of their culmen.
There are 18 species of penguins, but in this tutorial, you will only attempt to classify the following three:
Chinstrap penguins
Gentoo penguins
Adélie penguins
Fortunately, a research team has already created and shared a dataset of 344 penguins with body weight, flipper length, beak measurements, and other data. This dataset is also conveniently available as the penguins TensorFlow Dataset.
Setup
Install the tfds-nightly package for the penguins dataset. The tfds-nightly package is the nightly-released version of TensorFlow Datasets (TFDS). For more information on TFDS, see the TensorFlow Datasets overview.
pip install -q tfds-nightly
Then select Runtime > Restart Runtime from the Colab menu to restart the Colab runtime.
Do not proceed with the rest of this tutorial without first restarting the runtime.
Import TensorFlow and the other required Python modules.
import os
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
print("TensorFlow version: {}".format(tf.__version__))
print("TensorFlow Datasets version: ",tfds.__version__)
TensorFlow version: 2.9.0-rc1
TensorFlow Datasets version: 4.5.2+nightly
Import the dataset
The default penguins/processed TensorFlow Dataset is already cleaned, normalized, and ready for building a model. Before you download the processed data, preview a simplified version to get familiar with the original penguin survey data.
Preview the data
Download the simplified version of the penguins dataset (penguins/simple) using the TensorFlow Datasets tfds.load method. There are 344 data records in this dataset. Extract the first five records into a DataFrame object to inspect a sample of the values in this dataset:
ds_preview, info = tfds.load('penguins/simple', split='train', with_info=True)
df = tfds.as_dataframe(ds_preview.take(5), info)
print(df)
print(info.features)
body_mass_g culmen_depth_mm culmen_length_mm flipper_length_mm island \
0 4200.0 13.9 45.500000 210.0 0
1 4650.0 13.7 40.900002 214.0 0
2 5300.0 14.2 51.299999 218.0 0
3 5650.0 15.0 47.799999 215.0 0
4 5050.0 15.8 46.299999 215.0 0
sex species
0 0 2
1 0 2
2 1 2
3 1 2
4 1 2
FeaturesDict({
'body_mass_g': tf.float32,
'culmen_depth_mm': tf.float32,
'culmen_length_mm': tf.float32,
'flipper_length_mm': tf.float32,
'island': ClassLabel(shape=(), dtype=tf.int64, num_classes=3),
'sex': ClassLabel(shape=(), dtype=tf.int64, num_classes=3),
'species': ClassLabel(shape=(), dtype=tf.int64, num_classes=3),
})
2022-04-27 01:32:48.776548: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
The numbered rows are data records, one example per line, where:
The first six fields are features: these are the characteristics of an example. Here, the fields hold numbers representing penguin measurements.
The last column is the label: this is the value you want to predict. For this dataset, it's an integer value of 0, 1, or 2 that corresponds to a penguin species name.
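The feature/label split can be sketched on a single record. The values below are copied from the first preview row above; treating the record as a plain Python dict is only an illustration of the structure, not the TFDS API:

```python
# Hypothetical record shaped like one penguins/simple example
# (values taken from the first preview row above).
record = {
    'body_mass_g': 4200.0, 'culmen_depth_mm': 13.9,
    'culmen_length_mm': 45.5, 'flipper_length_mm': 210.0,
    'island': 0, 'sex': 0, 'species': 2,
}

# The first six fields are the features; 'species' is the label.
label = record.pop('species')
features = record
print(features)
print(label)  # 2
```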
In the dataset, the label for the penguin species is represented as a number to make it easier to work with in the model you are building. These numbers correspond to the following penguin species:
0: Adélie penguin
1: Chinstrap penguin
2: Gentoo penguin
Create a list containing the penguin species’ names in this order. You will use this list to interpret the output of the classification model:
class_names = ['Adélie', 'Chinstrap', 'Gentoo']
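As a quick sanity check, you can map an integer label back to its species name by indexing into this list (the label value below is just an example, standing in for a dataset record or a model prediction):

```python
class_names = ['Adélie', 'Chinstrap', 'Gentoo']

# Map an integer label (0, 1, or 2) back to its species name.
label = 2  # example label value
print(class_names[label])  # Gentoo
```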
Download the preprocessed dataset
Now, download the preprocessed penguins dataset (penguins/processed) with the tfds.load method, which returns a list of tf.data.Dataset objects. Note that the penguins/processed dataset doesn't come with its own test set, so use an 80:20 split to slice the full dataset into training and test sets. You will use the test dataset later to verify your model.
ds_split, info = tfds.load("penguins/processed", split=['train[:20%]', 'train[20%:]'], as_supervised=True, with_info=True)
ds_test = ds_split[0]
ds_train = ds_split[1]
assert isinstance(ds_test, tf.data.Dataset)
print(info.features)
df_test = tfds.as_dataframe(ds_test.take(5), info)
print("Test dataset sample: ")
print(df_test)
df_train = tfds.as_dataframe(ds_train.take(5), info)
print("Train dataset sample: ")
print(df_train)
ds_train_batch = ds_train.batch(32)
FeaturesDict({
'features': Tensor(shape=(4,), dtype=tf.float32),
'species': ClassLabel(shape=(), dtype=tf.int64, num_classes=3),
})
Test dataset sample:
features species
0 [0.6545454, 0.22619048, 0.89830506, 0.6388889] 2
1 [0.36, 0.04761905, 0.6440678, 0.4027778] 2
2 [0.68, 0.30952382, 0.91525424, 0.6944444] 2
3 [0.6181818, 0.20238096, 0.8135593, 0.6805556] 2
4 [0.5527273, 0.26190478, 0.84745765, 0.7083333] 2
Train dataset sample:
features species
0 [0.49818182, 0.6904762, 0.42372882, 0.4027778] 0
1 [0.48, 0.071428575, 0.6440678, 0.44444445] 2
2 [0.7236364, 0.9047619, 0.6440678, 0.5833333] 1
3 [0.34545454, 0.5833333, 0.33898306, 0.3472222] 0
4 [0.10909091, 0.75, 0.3559322, 0.41666666] 0
Notice that this version of the dataset has been processed by reducing the data down to four normalized features and a species label. In this format, the data can be quickly used to train a model without further processing.
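To see what a training batch looks like without re-downloading the dataset, here is a minimal sketch using synthetic data shaped like penguins/processed (random values standing in for the real normalized features):

```python
import tensorflow as tf

# Synthetic stand-in for penguins/processed: 100 examples with
# 4 normalized features each and an integer species label (0-2).
features = tf.random.uniform((100, 4))
labels = tf.random.uniform((100,), maxval=3, dtype=tf.int64)
ds = tf.data.Dataset.from_tensor_slices((features, labels))

# Batch the same way as ds_train_batch above, then inspect one batch.
batch_features, batch_labels = next(iter(ds.batch(32)))
print(batch_features.shape)  # (32, 4): 32 examples, 4 features each
print(batch_labels.shape)    # (32,)
```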