
Implementing structured pruning in Python

Pruning

Pruning is a machine learning technique that reduces the size of a model by removing unnecessary parameters, which cuts down both the memory the model occupies and the computation it requires. Pruning can be done in two ways:

  • Structured
  • Unstructured

I have discussed pruning techniques in detail in the following blog post: Pruning Techniques.
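
As a quick illustration of the difference, PyTorch ships its own pruning utilities in torch.nn.utils.prune (separate from the manual implementation below): unstructured pruning zeroes individual weights, while structured pruning zeroes entire rows of the weight matrix, i.e. whole neurons. A minimal sketch:

import torch.nn as nn
import torch.nn.utils.prune as prune

# Unstructured: zero the 30% of individual weights with the smallest magnitude
unstructured_layer = nn.Linear(8, 4)
prune.l1_unstructured(unstructured_layer, name="weight", amount=0.3)

# Structured: zero the 30% of output neurons (rows) with the smallest L1 norm
structured_layer = nn.Linear(8, 4)
prune.ln_structured(structured_layer, name="weight", amount=0.3, n=1, dim=0)

print(unstructured_layer.weight)  # scattered zeros
print(structured_layer.weight)    # entire rows zeroed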

Implementation

For our implementation, let's take an overly complicated neural network with 5 linear layers. The size of each layer can be seen in the definition below:

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(
    nn.Linear(2, 20),
    nn.ReLU(),
    nn.Linear(20, 18),
    nn.ReLU(),
    nn.Linear(18, 16),
    nn.ReLU(),
    nn.Linear(16, 14),
    nn.ReLU(),
    nn.Linear(14, 2),
    nn.Sigmoid()
)
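
As a small sanity check (not part of the pruning logic itself), we can count the parameters of this unpruned network to get a baseline to compare the pruned models against:

total_params = sum(p.numel() for p in model.parameters())
print(f"Unpruned model parameters: {total_params}")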

For implementing one-shot pruning, we can define some utility functions that will help us in the process.

# Training loop
def train(model, X_train_tensor, X_val_tensor, y_train_tensor, y_val_tensor, epochs=100):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        outputs = model(X_train_tensor)
        loss = criterion(outputs, y_train_tensor)
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 10 == 0:
            model.eval()
            with torch.no_grad():
                val_outputs = model(X_val_tensor)
                val_loss = criterion(val_outputs, y_val_tensor)
                print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}")

# Returns the weight matrix with the columns of previously pruned neurons removed
def get_data(weight_data, layer_shape, layers_pruned):
    return weight_data[:, [col for col in range(layer_shape[1]) if col not in layers_pruned]]

# Creates a Linear layer with the given weights and appends it (plus its activation)
# to the list of unpruned layers; biases keep their default initialisation
def add_layer(unpruned_layers, input_shape, output_shape, layer_data, activation=nn.ReLU()):
    layer = nn.Linear(input_shape, output_shape)
    with torch.no_grad():
        layer.weight.data = layer_data
    unpruned_layers.append(layer)
    unpruned_layers.append(activation)

get_data()

This function takes the weight matrix, the layer shape, and the indices of the pruned neurons, and returns the weight matrix with the columns corresponding to the pruned neurons removed.

add_layer()

This function takes the list of unpruned layers, the input and output shapes, the layer's weight data, and an activation function. It creates a new nn.Linear layer, assigns the supplied weights to its edges, and appends the layer and its activation to the list of unpruned layers.
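
To put these pieces together, the model can be trained before pruning. The synthetic two-class dataset below is only a placeholder I'm assuming for illustration; any data with 2 input features and 2 classes would work:

# Toy dataset: class 1 when the two features have the same sign
torch.manual_seed(0)
X = torch.randn(400, 2)
y = (X[:, 0] * X[:, 1] > 0).long()

X_train_tensor, X_val_tensor = X[:300], X[300:]
y_train_tensor, y_val_tensor = y[:300], y[300:]

train(model, X_train_tensor, X_val_tensor, y_train_tensor, y_val_tensor, epochs=100)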

One-shot Pruning (Reinitialization with Trained Weights)

# In this method, we prune a prune_ratio fraction of the features in each layer.
# nn.Linear layers are randomly initialised when created, so the trained weights
# are copied in explicitly via add_layer().
def oneshot_pruning(post_training_model, input_shape, output_shape, prune_ratio=0.2):
    unpruned_layers = []
    layer_index = 0
    layers_pruned = []
    for name, param in post_training_model.named_parameters():
        if 'weight' in name:
            # Not pruning the last output layer
            if layer_index == len(post_training_model) - 2:
                add_layer(unpruned_layers,
                          input_shape,
                          output_shape,
                          get_data(post_training_model[layer_index].weight.data,
                                   param.data.shape,
                                   layers_pruned),
                          nn.Sigmoid())
                continue
            # Sorting the features in a layer based on their L1 norm
            param_with_skipped_input = get_data(post_training_model[layer_index].weight.data, param.data.shape, layers_pruned)
            sorted_layers = torch.linalg.norm(param_with_skipped_input, ord=1, dim=1).argsort(dim=-1)
            layers_not_pruned = sorted(sorted_layers[int(prune_ratio * param_with_skipped_input.shape[0]):].tolist())
            layers_pruned = sorted(sorted_layers[:int(prune_ratio * param_with_skipped_input.shape[0])].tolist())

            # Initialising the unpruned neurons with their trained (post-training) values
            layer_data = param_with_skipped_input[layers_not_pruned, :]
            add_layer(unpruned_layers, input_shape, layer_data.shape[0], layer_data)
            input_shape = layer_data.shape[0]
            # Step past the interleaved ReLU to the next Linear layer
            layer_index = layer_index + 2
    return nn.Sequential(*unpruned_layers)

The check for the last layer

We don't prune the last layer, because in a classification model its neurons correspond to the output classes. add_layer() is simply called to copy it over to the unpruned layers.

All other layers

We sort the features in a layer based on their L1 norm and remove the features with the lowest norms. The indices of the removed features are stored in layers_pruned, and the indices of the retained features in layers_not_pruned. The retained features are then added to the unpruned_layers array, the activation layer is added in between, and the process repeats until we reach the last layer.
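
As a usage sketch (the input and output shapes here are those of the toy model above), the trained model can be pruned and its size compared to the original:

pruned_model = oneshot_pruning(model, input_shape=2, output_shape=2, prune_ratio=0.2)

print(pruned_model)
print("Parameters before pruning:", sum(p.numel() for p in model.parameters()))
print("Parameters after pruning: ", sum(p.numel() for p in pruned_model.parameters()))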

One-shot Pruning (Reinitialization with Initialized Weights)

def oneshot_pruning_reinit(post_training_model, pre_training_model, input_shape, output_shape, prune_ratio=0.2):
    unpruned_layers = []
    layer_index = 0
    layers_pruned = []
    for name, param in post_training_model.named_parameters():
        if 'weight' in name:
            # Not pruning the last output layer
            if layer_index == len(post_training_model) - 2:
                add_layer(unpruned_layers,
                          input_shape,
                          output_shape,
                          get_data(post_training_model[layer_index].weight.data,
                                   param.data.shape,
                                   layers_pruned),
                          nn.Sigmoid())
                continue
            # Sorting the features in a layer based on their L1 norm,
            # this time using the pre-training (freshly initialised) weights
            param_with_skipped_input = pre_training_model[layer_index].weight.data[:, [col for col in range(param.data.shape[1]) if col not in layers_pruned]]
            sorted_layers = torch.linalg.norm(param_with_skipped_input, ord=1, dim=1).argsort(dim=-1)
            layers_not_pruned = sorted(sorted_layers[int(prune_ratio * param_with_skipped_input.shape[0]):].tolist())
            layers_pruned = sorted(sorted_layers[:int(prune_ratio * param_with_skipped_input.shape[0])].tolist())

            # Initialising the unpruned neurons with their pre-training values
            layer_data = param_with_skipped_input[layers_not_pruned, :]
            add_layer(unpruned_layers, input_shape, layer_data.shape[0], layer_data)
            input_shape = layer_data.shape[0]
            # Step past the interleaved ReLU to the next Linear layer
            layer_index = layer_index + 2
    # add_layer() has already copied the pre-training weights into each nn.Linear,
    # so the pruned model can be returned directly
    return nn.Sequential(*unpruned_layers)

This is similar to the previous method, except that here we re-initialize the pruned model with the values of the pre-training model: param_with_skipped_input is taken from pre_training_model instead of post_training_model.
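
A typical way to use this variant is to snapshot the untrained model with a deepcopy before calling train(); the post does not show this step, so the sketch below restates the full flow under that assumption, starting from a freshly initialised model:

import copy

# Snapshot the freshly initialised model before it is trained
pre_training_model = copy.deepcopy(model)

train(model, X_train_tensor, X_val_tensor, y_train_tensor, y_val_tensor, epochs=100)

pruned_reinit_model = oneshot_pruning_reinit(model, pre_training_model,
                                             input_shape=2, output_shape=2,
                                             prune_ratio=0.2)

# Retrain the smaller network starting from its original initialisation
train(pruned_reinit_model, X_train_tensor, X_val_tensor, y_train_tensor, y_val_tensor, epochs=100)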

This post is licensed under CC BY 4.0 by the author.