Understanding the Need for expand_dims
Imagine you have a NumPy array representing a single data point, such as a grayscale image pixel value:
import numpy as np
pixel_value = np.array(100)
print(pixel_value.shape)  # Output: () (a 0-dimensional array)
Many machine learning algorithms expect input data to have a specific number of dimensions. For example, a model might anticipate a batch of images, where each image is represented as a 2D array (height x width), and the batch itself forms a third dimension. Simply passing pixel_value won’t work, because a 0-dimensional array has no axes at all. This is where expand_dims steps in.
Adding Dimensions with expand_dims
The expand_dims function takes two arguments:
a: The input array.
axis: The position where you want to insert the new axis. Remember that Python uses zero-based indexing.
Let’s illustrate:
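row_vector = np.expand_dims(pixel_value, axis=0)
print(row_vector.shape)  # Output: (1,)
# axis=1 is out of range for a 0-dimensional input (recent NumPy versions raise an error),
# so the second axis is added to row_vector instead.
column_vector = np.expand_dims(row_vector, axis=1)
print(column_vector.shape)  # Output: (1, 1)
Notice how the first call turns the 0-dimensional array into a 1-dimensional array of shape (1,), and the second call adds another axis to produce a 2-dimensional array of shape (1, 1).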
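The effect of the axis argument is easier to see on an array with more than one element. The following is a minimal sketch, not taken from the original text: the array vec and its values are made up for illustration, and the tuple form of axis assumes NumPy 1.18 or newer.

```python
import numpy as np

# A 1-D array with three elements; the values are arbitrary.
vec = np.array([10, 20, 30])

as_row = np.expand_dims(vec, axis=0)          # new axis in front
as_column = np.expand_dims(vec, axis=1)       # new axis after the existing one
as_column_neg = np.expand_dims(vec, axis=-1)  # negative values count from the end

print(as_row.shape)         # (1, 3)
print(as_column.shape)      # (3, 1)
print(as_column_neg.shape)  # (3, 1)

# Assumption: NumPy 1.18 or newer, where axis may also be a tuple of positions.
both_sides = np.expand_dims(vec, axis=(0, 2))
print(both_sides.shape)     # (1, 3, 1)
```

In every case the data is unchanged; only the position of the length-one axis in the shape differs.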
Expanding Higher-Dimensional Arrays
The power of expand_dims extends beyond simple scalar values. Let’s consider a 2D array:
image = np.array([[1, 2], [3, 4]])
print(image.shape)  # Output: (2, 2)
batched_image = np.expand_dims(image, axis=0)
print(batched_image.shape)  # Output: (1, 2, 2)
Here, we’ve added a new dimension at the beginning, effectively creating a batch of size one. This is a common preprocessing step in deep learning.
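Two practical details follow from treating a single image as a batch of one. The sketch below, which uses a second made-up array other_image, shows that size-one batches can be joined along the batch axis with np.concatenate, and that the array returned by expand_dims is a view sharing memory with the original, so a copy is needed if the two should be independent.

```python
import numpy as np

image = np.array([[1, 2], [3, 4]])
other_image = np.array([[5, 6], [7, 8]])  # hypothetical second image

batched_image = np.expand_dims(image, axis=0)       # shape (1, 2, 2)
batched_other = np.expand_dims(other_image, axis=0)  # shape (1, 2, 2)

# Joining two batches of size one along axis 0 gives a batch of size two.
batch = np.concatenate([batched_image, batched_other], axis=0)
print(batch.shape)  # (2, 2, 2)

# expand_dims returns a view: no data is copied.
print(np.shares_memory(image, batched_image))  # True

# Writing through the view changes the original array as well.
batched_image[0, 0, 0] = 99
print(image[0, 0])  # 99

# Use .copy() when the expanded array must not affect the original.
independent = np.expand_dims(image, axis=0).copy()
independent[0, 0, 0] = -1
print(image[0, 0])  # still 99
```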
Practical Application: Reshaping for Model Input
Many machine learning models require a specific input shape. Let’s say a model expects input of shape (1, 28, 28) (a batch of one 28x28 image). You have an image represented as a (28, 28) NumPy array:
single_image = np.random.rand(28, 28)
reshaped_image = np.expand_dims(single_image, axis=0)
print(reshaped_image.shape)  # Output: (1, 28, 28)
This example showcases how expand_dims neatly prepares your data for model ingestion. You can similarly use it to add dimensions for channels (e.g., for color images), or any other required dimension.
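As a sketch of the channel case, the snippet below adds a channel axis to a grayscale image in two common layouts and then a batch axis on top. The layout names (channels-last, channels-first) and the target shapes are assumptions for illustration; which one a given model expects depends on the framework and the model itself.

```python
import numpy as np

single_image = np.random.rand(28, 28)

# Channels-last layout: (height, width, channels)
channels_last = np.expand_dims(single_image, axis=-1)
print(channels_last.shape)   # (28, 28, 1)

# Channels-first layout: (channels, height, width)
channels_first = np.expand_dims(single_image, axis=0)
print(channels_first.shape)  # (1, 28, 28)

# Adding a batch axis in front of the channels-last image:
# (batch, height, width, channels)
batch_of_one = np.expand_dims(channels_last, axis=0)
print(batch_of_one.shape)    # (1, 28, 28, 1)
```

Note that (1, 28, 28) can mean either "one channel" or "a batch of one", so it helps to record the intended layout in a comment next to the call.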
expand_dims vs. reshape
While reshape can also alter the shape of an array, expand_dims is specifically designed for adding axes of length one, making it more concise and readable for this particular task. With reshape, you have to spell out the complete target shape yourself, which is easier to get wrong and less intuitive.
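To make the comparison concrete, here is a small sketch that produces the same (1, 28, 28) result three ways. The variable names are illustrative only; the last part shows the kind of mistake that a hand-written target shape makes possible.

```python
import numpy as np

image = np.random.rand(28, 28)

via_expand = np.expand_dims(image, axis=0)

# reshape needs the full target shape, either hard-coded ...
via_reshape = image.reshape(1, 28, 28)
# ... or built from the existing shape.
via_reshape_generic = image.reshape((1,) + image.shape)

print(via_expand.shape, via_reshape.shape, via_reshape_generic.shape)
print(np.array_equal(via_expand, via_reshape))          # True
print(np.array_equal(via_expand, via_reshape_generic))  # True

# A typo in a hard-coded shape fails at runtime; expand_dims has no shape to mistype.
reshape_failed = False
try:
    image.reshape(1, 28, 27)
except ValueError as err:
    reshape_failed = True
    print("reshape failed:", err)
```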
Using np.newaxis for a More Concise Syntax
An alternative to using np.expand_dims is to use np.newaxis within array slicing, providing a slightly more compact notation:
single_image = np.random.rand(28, 28)
# Adds a new axis at the beginning; ... stands for all other axes.
reshaped_image = single_image[np.newaxis, ...]
print(reshaped_image.shape)  # Output: (1, 28, 28)
This approach offers a more concise way to add dimensions in specific situations. However, np.expand_dims remains more explicit and arguably easier to understand for beginners.
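For readers who want to map one notation onto the other, the sketch below pairs a few np.newaxis indexing expressions with the np.expand_dims call that gives the same result. The variable names are illustrative; np.newaxis itself is simply an alias for None.

```python
import numpy as np

single_image = np.random.rand(28, 28)

front = single_image[np.newaxis, ...]     # same as axis=0
back = single_image[..., np.newaxis]      # same as axis=-1
middle = single_image[:, np.newaxis, :]   # same as axis=1

print(front.shape)   # (1, 28, 28)
print(back.shape)    # (28, 28, 1)
print(middle.shape)  # (28, 1, 28)

# Each indexing form matches the corresponding expand_dims call.
print(np.array_equal(front, np.expand_dims(single_image, axis=0)))   # True
print(np.array_equal(back, np.expand_dims(single_image, axis=-1)))   # True
print(np.array_equal(middle, np.expand_dims(single_image, axis=1)))  # True

# np.newaxis is an alias for None.
print(np.newaxis is None)  # True
```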