TensorFlow implementation of [Die16] Dieleman, Sander, Jeffrey De Fauw, and Koray Kavukcuoglu: "Exploiting cyclic symmetry in convolutional neural networks"
A function $f$ is equivariant to a class of transformations $\mathcal T$ if $$ \forall \mathbf T \in \mathcal T: \exists \mathbf T': f(\mathbf T x) = \mathbf T' f(x) $$ for some transformation $\mathbf T'$ that may depend on $\mathbf T$.
A function is same-equivariant if it is equivariant with $\mathbf T = \mathbf T'$.
A function $f$ is invariant to a class of transformations $\mathcal T$ if $$ \forall \mathbf T \in \mathcal T: f(\mathbf T x) = f(x) $$
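A small numpy illustration of these definitions (a sketch, with np.rot90 playing the role of $\mathbf T$): the global mean of an image is invariant to rotation, while element-wise negation is same-equivariant.
import numpy as np
img = np.arange(12, dtype=float).reshape(3, 4)
# invariance: f(Tx) = f(x) for f = global mean
print(np.isclose(np.rot90(img).mean(), img.mean()))  # True
# same-equivariance: f(Tx) = T f(x) for f = element-wise negation
print(np.allclose(-np.rot90(img), np.rot90(-img)))  # True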
%matplotlib inline
import os
import matplotlib.pyplot as plt
import numpy as np
from libs import utils
from skimage.transform import resize
dirname = '/Users/christian/lehre-extern/tensorflow/CADL_my/session-1/img_align_celeba/'
# Load every image file in the provided directory
filenames = [os.path.join(dirname, fname)
for fname in os.listdir(dirname)]
# Make sure we have exactly 100 image files!
filenames = filenames[:100]
assert(len(filenames) == 100)
# Read every filename as an RGB image
imgs = [plt.imread(fname)[..., :3] for fname in filenames]
# Crop every image to a square
imgs = [utils.imcrop_tosquare(img_i) for img_i in imgs]
# Then resize the square image to 100 x 100 pixels
imgs = [resize(img_i, (100, 100)) for img_i in imgs]
imgs = np.array(imgs)
#plt.figure(figsize=(3, 3))
#plt.imshow(imgs[1])
batch_size=10
X = imgs[:batch_size]
X.shape
$$ S(X) = [X, rX, r^2X, r^3X]^T $$
with the 90° rotation $r$. The size of the minibatch increases (4 times): the rotated examples are stacked across the batch dimension, so the batch of shape (10, 100, 100, 3) becomes (40, 100, 100, 3).
$$ S(X) = [X, mX, rX, rmX, r^2X, r^2mX, r^3X, r^3mX]^T $$
with the additional horizontal flip $m$ (the ordering follows the implementation below). The size of the minibatch increases (8 times): the rotated and/or flipped examples are stacked across the batch dimension, giving a batch of shape (80, 100, 100, 3).
import tensorflow as tf
g = tf.get_default_graph()
sess = tf.Session(graph=g)
def get_rotations(X, r):
    # Rotate each tensor in the list X by r[i] * 90 degrees.
    # tf.image.rot90 rotates a single image counter-clockwise, so it is
    # mapped over the batch dimension; r[i] is bound via a default
    # argument to avoid late-binding issues in the lambda.
    p = len(r)
    X_rot = [tf.map_fn(lambda img, k=r[i]: tf.image.rot90(img, k), X[i])
             for i in range(p)]
    return X_rot
def concat_horizontal_flipped(X):
    # Append the horizontally mirrored batch (flip along the width axis)
    X_m = tf.reverse(X, axis=[2])
    return tf.concat((X, X_m), axis=0)
X_hflipped = concat_horizontal_flipped(X)
X_r = sess.run(X_hflipped)
plt.figure(figsize=(3, 3))
plt.imshow(X_r[10])
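A quick sanity check (a sketch using the ops above): the horizontal flip $m$ is an involution, so flipping the mirrored half of the batch recovers the original images.
# m^2 = identity: flipping the flipped batch again recovers X
X_back = sess.run(tf.reverse(X_hflipped, axis=[2]))
print(np.allclose(X_back[batch_size:], X))  # expected: True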
def cyclic_slicing(X):
    # Stack the four rotated copies [X, rX, r^2X, r^3X] across the batch
    # dimension; with r a 90° clockwise rotation, the rot90 counts are
    # k = (0, 3, 2, 1) since tf.image.rot90 rotates counter-clockwise
    X_ = get_rotations([X]*4, r=(0, 3, 2, 1))
    return tf.concat(X_, axis=0)
def dihedral_slicing(X):
    # Flip first, then rotate: yields all 8 transformations of the dihedral group
    X_ = concat_horizontal_flipped(X)
    return cyclic_slicing(X_)
X_sliced_4 = cyclic_slicing(X)
X_s_4 = sess.run(X_sliced_4)
X_sliced_8 = dihedral_slicing(X)
X_s_8 = sess.run(X_sliced_8)
plt.figure(figsize=(3, 3))
plt.imshow(X_s_8[20])
Input to the pooling layer:
$$ X = [X_0, X_1, X_2, X_3]^T $$
$$ P(X) = p(X_0, r^{-1}X_1, r^{-2}X_2, r^{-3}X_3) $$
with a permutation-invariant pooling function $p$ (e.g. the mean or the maximum). Pooling is typically done in the dense layers, where the features no longer carry spatial structure, so no back-rotation (or back-flipping) is needed.
def dense_pooling(X, pool_op, p=8):
    # pool_op has to be a tf reduce operator, e.g. tf.reduce_mean
    # no spatial structure -> no back-rotation (r^-1) needed
    shape = tf.shape(X)
    # split the batch into the p slice segments and pool across them
    X_ = tf.reshape(X, (p, shape[0] // p, shape[1]))
    return pool_op(X_, axis=0)
X_s_8_shape = X_s_8.shape
X_s_8_ = X_s_8.reshape((X_s_8_shape[0], -1))
print(X_s_8_.shape)
pool_op = tf.reduce_mean
sess.run(dense_pooling(X_s_8_, pool_op, p=8)).shape
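Since the pooling function is permutation-invariant, the pooled representation does not change when the slices are cyclically shifted. A quick numpy check on the flattened batch from above (a sketch):
segs = X_s_8_.reshape(8, -1, X_s_8_.shape[1])
# mean-pooling over the p segments is invariant to a cyclic shift of the segments
print(np.allclose(segs.mean(axis=0), np.roll(segs, 2, axis=0).mean(axis=0)))  # True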
$X$ is a minibatch. Slicing the rotated minibatch yields the original slices in cyclically shifted order:
$$ S(rX) = [rX, r^2X, r^3X, X]^T = \sigma S(X) $$
with the cyclic permutation $\sigma$ shifting the elements backwards.
def segmentation(X, p=8):
    # Reshape the sliced batch into p segments of equal size
    shape = tf.shape(X)
    return tf.reshape(X, (p, shape[0] // p, shape[1], shape[2], shape[3]))
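A quick shape check (using the 8-fold sliced batch from above): segmentation splits the batch dimension into p segments.
print(sess.run(tf.shape(segmentation(X_sliced_8, p=8))))  # [8 10 100 100 3]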
# just for visualization purposes
def add_colors_8(X):
shape = np.shape(X)
X = X.copy()
X_ = np.reshape(X, (8, shape[0]//8, shape[1], shape[2], shape[3]))
X_[0,:,:,:,0] += 0.3
X_[1,:,:,:,0] += 0.5
X_[2,:,:,:,1] += 0.3
X_[3,:,:,:,1] += 0.5
X_[4,:,:,:,2] += 0.3
X_[5,:,:,:,2] += 0.5
X_[6,:,:,:,0] += 0.3
X_[6,:,:,:,1] += 0.3
X_[7,:,:,:,0] += 0.5
X_[7,:,:,:,1] += 0.5
X_[X_>1.] = 1.
return np.concatenate((X_[0], X_[1], X_[2], X_[3], X_[4], X_[5], X_[6], X_[7]), 0)
# just for visualization purposes
def add_colors_4(X):
shape = np.shape(X)
X = X.copy()
X_ = np.reshape(X, (4, shape[0]//4, shape[1], shape[2], shape[3]))
X_[0,:,:,:,0] += 0.5
X_[1,:,:,:,1] += 0.5
X_[2,:,:,:,2] += 0.5
X_[3,:,:,:,0] += 0.5
X_[3,:,:,:,1] += 0.5
X_[X_>1.] = 1.
return np.concatenate((X_[0], X_[1], X_[2], X_[3]), 0)
X_colors_4 = add_colors_4(X_s_4)
plt.figure(figsize=(2, 6))
p=4
for i in range(p):
plt.subplot(p, 1, i+1)
plt.imshow(X_colors_4[batch_size*i, :, :, :])
X_colors_8 = add_colors_8(X_s_8)
plt.figure(figsize=(5, 12))
p=8
for i in range(p):
plt.subplot(p, 1, i+1)
plt.imshow(X_colors_8[batch_size*i, :, :, :])
def cyclic_permutation(X, shift=1, p=8):
    # Shift the p segments of a sliced batch backwards by `shift`
    perm = np.roll(range(p), shift=-shift)
    X_ = segmentation(X, p=p)
    X_concat = [X_[perm[i]] for i in range(p)]
    return tf.concat(X_concat, axis=0)
X_p = sess.run(cyclic_permutation(X_sliced_4, 2, 4))
X_p.shape
plt.figure(figsize=(3, 3))
plt.imshow(X_p[0])
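A numerical check of $S(rX) = \sigma S(X)$ (a sketch; it assumes that np.rot90 with k=3 matches the clockwise rotation $r$ used in cyclic_slicing):
# rotate the whole minibatch by r (90° clockwise = three 90° counter-clockwise turns)
X_rot = np.rot90(X, k=3, axes=(1, 2)).copy()
lhs = sess.run(cyclic_slicing(tf.constant(X_rot)))  # S(rX)
rhs = sess.run(cyclic_permutation(cyclic_slicing(tf.constant(X)), shift=1, p=4))  # sigma S(X)
print(np.allclose(lhs, rhs))  # expected: True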
Stacking along the feature dimensions:
$$ T(X) = [X_0, r^{-1}X_1, r^{-2}X_2, r^{-3}X_3] $$
with $X = [X_0, X_1, X_2, X_3]^T$.
def stack(X):
    # Undo the slice rotations and concatenate along the channel axis;
    # the inverse rot90 counts of the slicing (0, 3, 2, 1) are (0, 1, 2, 3)
    X_ = segmentation(X, 4)
    X_ = [X_[i] for i in range(4)]
    X_ = get_rotations(X_, r=(0, 1, 2, 3))
    return tf.concat(X_, axis=3)
X_stacked_4 = sess.run(stack(cyclic_permutation(X_colors_4, 3, p=4)))
X_stacked_4.shape
plt.figure(figsize=(6, 2.5))
p=4
for i in range(p):
plt.subplot(1, p, i+1)
plt.imshow(X_stacked_4[0, :, :, (i*3):(i*3)+3])
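Without the added colors, the back-rotations align all four copies exactly: stacking the plain sliced batch yields four identical copies of $X$ along the channel axis (a quick check using the functions above).
X_st = sess.run(stack(cyclic_slicing(tf.constant(X))))
print(X_st.shape)  # (10, 100, 100, 12)
print(np.allclose(X_st[..., 0:3], X_st[..., 9:12]))  # expected: True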