Transpose Convolution to Reverse the Convolution Operation

A traditional convolutional layer takes a patch of an image and produces a number (patch -> number). In a “transpose convolution” we want the opposite: take a number and produce a patch of an image (number -> patch). We need this layer to “undo” the convolutions of the encoder, so it is used specifically in decoder-type operations. Note that “undo” here means restoring the spatial shape, not recovering the exact input values.
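
The two directions can be illustrated with a minimal NumPy sketch. The 3x3 filter of ones and the sample values are hypothetical, chosen only to keep the arithmetic easy to follow:

```python
import numpy as np

kernel = np.ones((3, 3), dtype=np.float32)  # hypothetical 3x3 filter of ones

# Convolution: patch -> number (elementwise product with the filter, then sum).
patch = np.arange(9, dtype=np.float32).reshape(3, 3)
number = float((patch * kernel).sum())

# Transpose convolution: number -> patch (the number scales the whole filter).
new_patch = 2.0 * kernel

print(number)     # 36.0
print(new_patch)  # 3x3 array filled with 2.0
```

In a full layer these patches are pasted into the output at stride-spaced positions, and overlapping patches are summed.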

Test Image

In this example we use a stride of 2 to produce a 4x4 output; this way we “undo” pooling as well. Another way to think about it: we “undo” a convolution with stride 2 (which is similar to conv + pool).
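
The stride-2 upsampling described above can be sketched in plain NumPy. The function name conv2d_transpose_same is hypothetical, and the cropping rule is an assumption meant to mimic TensorFlow-style 'SAME' padding (extra rows and columns are dropped mostly from the bottom and right); treat it as an illustration rather than a reference implementation:

```python
import numpy as np

def conv2d_transpose_same(x, kernel, stride=2):
    """Transpose convolution of a square 2-D array, cropped to size * stride."""
    n, k = x.shape[0], kernel.shape[0]
    assert k >= stride, "this sketch assumes kernel size >= stride"
    full_size = (n - 1) * stride + k
    full = np.zeros((full_size, full_size), dtype=np.float32)
    # Each input number pastes a scaled copy of the kernel into the output;
    # overlapping copies are summed.
    for i in range(n):
        for j in range(n):
            r, c = i * stride, j * stride
            full[r:r + k, c:c + k] += x[i, j] * kernel
    # Assumed 'SAME' convention: drop the (k - stride) extra rows/columns,
    # floor(half) from the top/left and the rest from the bottom/right.
    out_size = n * stride
    start = (k - stride) // 2
    return full[start:start + out_size, start:start + out_size]

x = np.array([[1, 2], [3, 4]], dtype=np.float32)
print(conv2d_transpose_same(x, np.ones((2, 2), dtype=np.float32)))
print(conv2d_transpose_same(x, np.ones((3, 3), dtype=np.float32)))
```

With a 2x2 filter the patches do not overlap, so every input value is simply repeated in a 2x2 block. With a 3x3 filter neighbouring patches overlap by one row and column, so the shared cells hold sums of adjacent input values.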

You can add a “transpose convolution” layer in Keras like this (here L is assumed to be an alias for keras.layers):

L.Conv2DTranspose(filters=?, kernel_size=(3, 3), strides=2, activation='elu', padding='same')

Check that your decoder starts with a dense layer to “undo” the last layer of the encoder. Remember to reshape its output to “undo” L.Flatten() in the encoder.
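
The shape bookkeeping behind this advice can be sketched in NumPy, with matrix products standing in for dense layers and reshape standing in for Flatten and its inverse. All sizes here (a 4x4x8 feature map, a 32-unit code) and the random weights are hypothetical; only the shapes matter:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: the encoder's last feature map is 4x4 with 8 channels,
# and the code (bottleneck) has 32 units.
feature_shape = (4, 4, 8)
code_size = 32
flat_size = int(np.prod(feature_shape))  # 4 * 4 * 8 = 128

# Encoder tail: flatten, then dense (feature map -> code).
features = rng.normal(size=(1, *feature_shape))
w_enc = rng.normal(size=(flat_size, code_size))
code = features.reshape(1, flat_size) @ w_enc

# Decoder head: a dense layer maps the code back to a flat vector,
# then a reshape restores the feature-map shape that Flatten removed.
w_dec = rng.normal(size=(code_size, flat_size))
restored = (code @ w_dec).reshape(1, *feature_shape)

print(code.shape)      # (1, 32)
print(restored.shape)  # (1, 4, 4, 8)
```

Only after this reshape does the tensor have the height, width, and channel axes that a transpose convolution layer expects.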

Code Example

import numpy as np
import tensorflow as tf

def test_conv2d_transpose(img_size, filter_size):
    print("Transpose convolution test for img_size={}, filter_size={}:".format(img_size, filter_size))

    # Input image with values 1..img_size**2, shape (batch, height, width, channels).
    x = (np.arange(img_size ** 2, dtype=np.float32) + 1).reshape((1, img_size, img_size, 1))
    # Filter of ones, shape (height, width, output_channels, input_channels).
    f = (np.ones(filter_size ** 2, dtype=np.float32)).reshape((filter_size, filter_size, 1, 1))

    # Stride 2 with 'SAME' padding doubles the spatial size of the input.
    # TensorFlow 2.x executes eagerly, so no tf.Session is needed.
    conv = tf.nn.conv2d_transpose(x, f,
                                  output_shape=(1, img_size * 2, img_size * 2, 1),
                                  strides=[1, 2, 2, 1],
                                  padding='SAME')
    result = conv.numpy()

    print("input:")
    print(x[0, :, :, 0])
    print("filter:")
    print(f[:, :, 0, 0])
    print("output:")
    print(result[0, :, :, 0])


test_conv2d_transpose(img_size=2, filter_size=2)
test_conv2d_transpose(img_size=2, filter_size=3)
test_conv2d_transpose(img_size=4, filter_size=2)
test_conv2d_transpose(img_size=4, filter_size=3)

Outputs

Transpose convolution test for img_size=2, filter_size=2:
input:
[[1. 2.]
 [3. 4.]]
filter:
[[1. 1.]
 [1. 1.]]
output:
[[1. 1. 2. 2.]
 [1. 1. 2. 2.]
 [3. 3. 4. 4.]
 [3. 3. 4. 4.]]
   
 
Transpose convolution test for img_size=2, filter_size=3:
input:
[[1. 2.]
 [3. 4.]]
filter:
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
output:
[[ 1.  1.  3.  2.]
 [ 1.  1.  3.  2.]
 [ 4.  4. 10.  6.]
 [ 3.  3.  7.  4.]]  
 
 
Transpose convolution test for img_size=4, filter_size=2:  
input:
[[ 1.  2.  3.  4.]
 [ 5.  6.  7.  8.]
 [ 9. 10. 11. 12.]
 [13. 14. 15. 16.]]
filter:
[[1. 1.]
 [1. 1.]]
output:
[[ 1.  1.  2.  2.  3.  3.  4.  4.]
 [ 1.  1.  2.  2.  3.  3.  4.  4.]
 [ 5.  5.  6.  6.  7.  7.  8.  8.]
 [ 5.  5.  6.  6.  7.  7.  8.  8.]
 [ 9.  9. 10. 10. 11. 11. 12. 12.]
 [ 9.  9. 10. 10. 11. 11. 12. 12.]
 [13. 13. 14. 14. 15. 15. 16. 16.]
 [13. 13. 14. 14. 15. 15. 16. 16.]]  
   
   
Transpose convolution test for img_size=4, filter_size=3:  
input:
[[ 1.  2.  3.  4.]
 [ 5.  6.  7.  8.]
 [ 9. 10. 11. 12.]
 [13. 14. 15. 16.]]
filter:
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
output:
[[ 1.  1.  3.  2.  5.  3.  7.  4.]
 [ 1.  1.  3.  2.  5.  3.  7.  4.]
 [ 6.  6. 14.  8. 18. 10. 22. 12.]
 [ 5.  5. 11.  6. 13.  7. 15.  8.]
 [14. 14. 30. 16. 34. 18. 38. 20.]
 [ 9.  9. 19. 10. 21. 11. 23. 12.]
 [22. 22. 46. 24. 50. 26. 54. 28.]
 [13. 13. 27. 14. 29. 15. 31. 16.]]
Written on October 2, 2018