numpy.split | Split an array into multiple sub-arrays in Python

In this article, we will learn how to split an array into multiple sub-arrays in Python. To divide an array into sub-arrays, we will use the numpy.split() function.

Split an array into multiple sub-arrays in Python

To understand the numpy.split() function, we first need to look at its syntax:

numpy.split(ary, indices_or_sections, axis=0)

ary: The input array to be divided into sub-arrays (called a in the examples below).

indices_or_sections: Either an integer or a 1-D array of indices.

  • Integer: If indices_or_sections is an integer n, the input array is divided into n equal sub-arrays along the given axis. If such an equal split is not possible, NumPy raises a ValueError.
    For example, if an input array a contains 9 elements, np.split(a, 3) splits it into 3 sub-arrays containing 3 elements each.
  • A 1-D array: If indices_or_sections is a 1-D array, its entries should be in ascending (sorted) order, and each entry marks the index where a new sub-array starts.
    For example, np.split(a, [2, 4, 7]) splits a 9-element array a into a[:2], a[2:4], a[4:7] and a[7:].
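The two forms of indices_or_sections can be compared on the same 9-element array. This is a minimal sketch; the variable names are only for illustration, and splitting into 4 is chosen just to trigger the error case, since 9 is not divisible by 4.

```python
import numpy as np

a = np.arange(9)  # [0 1 2 3 4 5 6 7 8]

# Integer: 9 elements split into 3 equal sub-arrays of 3 elements each
equal_parts = np.split(a, 3)
print(equal_parts)  # [array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8])]

# 1-D array of sorted indices: each index marks where a new sub-array starts
parts = np.split(a, [2, 4, 7])
print(parts)  # [array([0, 1]), array([2, 3]), array([4, 5, 6]), array([7, 8])]

# An integer that does not divide the length evenly raises ValueError
try:
    np.split(a, 4)
    uneven_raised = False
except ValueError as err:
    uneven_raised = True
    print("ValueError:", err)
```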

axis: The axis along which to split. It defaults to 0 and can be any valid axis of the input array; for a 3-D array that means 0, 1 or 2.

  • 0 is the first axis (rows). Splitting along axis 0 separates the array into groups of rows, i.e. it splits the array vertically. Instead of axis=0 we can also write np.vsplit(a, sections).
  • 1 is the second axis (columns). Splitting along axis 1 separates the array into groups of columns, i.e. it splits the array horizontally. Instead of axis=1 we can also write np.hsplit(a, sections) (for a 1-D array, np.hsplit splits along axis 0).
  • 2 is the third axis (depth). Splitting along axis 2 divides the array into sub-arrays along the depth. Instead of axis=2 we can also write np.dsplit(a, sections), which requires an array with at least 3 dimensions.
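Because the axis names are easy to mix up, the sketch below checks that np.vsplit, np.hsplit and np.dsplit give the same pieces as np.split with axis 0, 1 and 2. The shape (4, 6, 2) is an arbitrary choice so that each axis has a different length and the resulting shapes are easy to tell apart.

```python
import numpy as np

# Arbitrary shape (4, 6, 2): 4 rows, 6 columns, depth 2
A = np.arange(48).reshape(4, 6, 2)

pairs = {
    # np.vsplit splits along axis 0 (groups of rows)
    "vsplit": (np.split(A, 2, axis=0), np.vsplit(A, 2)),
    # np.hsplit splits along axis 1 (groups of columns) for arrays with 2+ dims
    "hsplit": (np.split(A, 2, axis=1), np.hsplit(A, 2)),
    # np.dsplit splits along axis 2 (depth) and needs at least 3 dims
    "dsplit": (np.split(A, 2, axis=2), np.dsplit(A, 2)),
}

for name, (by_split, by_helper) in pairs.items():
    same = all(np.array_equal(p, q) for p, q in zip(by_split, by_helper))
    print(name, same, [p.shape for p in by_split])
# vsplit True [(2, 6, 2), (2, 6, 2)]
# hsplit True [(4, 3, 2), (4, 3, 2)]
# dsplit True [(4, 6, 1), (4, 6, 1)]
```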

Examples

 

import numpy as np

a = np.arange(9)
print("1st array is\n", a)
# axis is not given, so the default axis=0 is used
print("2nd array is\n", np.split(a, [3, 7]))

In the code above, np.split(a, [3, 7]) splits the array a into 3 parts: a[:3], a[3:7] and a[7:]. Because no axis is specified, the default axis=0 is used.

If you run the code, the output will be:

Output:
1st array is
 [0 1 2 3 4 5 6 7 8]
2nd array is
 [array([0, 1, 2]), array([3, 4, 5, 6]), array([7, 8])]
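To check the slicing description above, we can compare each piece returned by np.split(a, [3, 7]) with the corresponding slice. The sketch also shows that, as the NumPy documentation states, the sub-arrays are views into the original array rather than copies, so writing to a piece also changes a.

```python
import numpy as np

a = np.arange(9)
pieces = np.split(a, [3, 7])

# The indices 3 and 7 start new pieces, so the result equals these slices
slices = [a[:3], a[3:7], a[7:]]
matches = all(np.array_equal(p, s) for p, s in zip(pieces, slices))
print(matches)  # True

# The pieces are views into a: writing to a piece also changes a
pieces[0][0] = 100
print(a[0])  # 100
```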

import numpy as np

A = np.arange(27).reshape(3, 3, 3)
a = np.split(A, 3, axis=0)  # split row-wise (same as np.vsplit)
print("1st array-\n", a)
b = np.split(A, 3, axis=1)  # split column-wise (same as np.hsplit)
print("2nd array-\n", b)
c = np.split(A, 3, axis=2)  # split depth-wise (same as np.dsplit)
print("3rd array-\n", c)

Here, we have split the array row-wise, column-wise and depth-wise by setting axis to 0, 1 and 2 respectively. Each call returns three sub-arrays, with shapes (1, 3, 3), (3, 1, 3) and (3, 3, 1) respectively.

The output will be:

Output:
1st array-
 [array([[[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]]]), array([[[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]]]), array([[[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]])]
2nd array-
 [array([[[ 0,  1,  2]],

       [[ 9, 10, 11]],

       [[18, 19, 20]]]), array([[[ 3,  4,  5]],

       [[12, 13, 14]],

       [[21, 22, 23]]]), array([[[ 6,  7,  8]],

       [[15, 16, 17]],

       [[24, 25, 26]]])]
3rd array-
 [array([[[ 0],
        [ 3],
        [ 6]],

       [[ 9],
        [12],
        [15]],

       [[18],
        [21],
        [24]]]), array([[[ 1],
        [ 4],
        [ 7]],

       [[10],
        [13],
        [16]],

       [[19],
        [22],
        [25]]]), array([[[ 2],
        [ 5],
        [ 8]],

       [[11],
        [14],
        [17]],

       [[20],
        [23],
        [26]]])]
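Because the printed output is long, a shorter way to see what each axis does is to print only the shapes of the pieces. The np.concatenate call at the end is an extra check that is not part of the original example: joining the pieces back along the same axis should reproduce A.

```python
import numpy as np

A = np.arange(27).reshape(3, 3, 3)

# Record the shape of each piece for a 3-way split along every axis
shapes = {}
for axis in (0, 1, 2):
    shapes[axis] = [p.shape for p in np.split(A, 3, axis=axis)]
    print(axis, shapes[axis])
# 0 [(1, 3, 3), (1, 3, 3), (1, 3, 3)]
# 1 [(3, 1, 3), (3, 1, 3), (3, 1, 3)]
# 2 [(3, 3, 1), (3, 3, 1), (3, 3, 1)]

# Joining the pieces along the same axis rebuilds the original array
rebuilt = np.concatenate(np.split(A, 3, axis=2), axis=2)
print(np.array_equal(rebuilt, A))  # True
```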
