diff --git a/projection_matrix/projection_matrix.py b/projection_matrix/projection_matrix.py
index b3c808af2c274ffa5a83e15f68b0f4e3875c5afb..12ac694b8ba4a5464f24a7122692f2e0b201b04a 100644
--- a/projection_matrix/projection_matrix.py
+++ b/projection_matrix/projection_matrix.py
@@ -48,7 +48,7 @@ def idx_array_slice(D, axis, slice_idx):
 
 
 def projection_matrix(D, xyz, labels=None, projection='max_slice',
-                      max_n_ticks=4, factor=3, **kwargs):
+                      max_n_ticks=4, factor=3, whspace=0.05, **kwargs):
     """ Generate a projection matrix plot
 
     Parameters
@@ -80,9 +80,8 @@ def projection_matrix(D, xyz, labels=None, projection='max_slice',
 
     """
     ndim = D.ndim
-    lbdim = 0.5 * factor   # size of left/bottom margin
+    lbdim = 0.4 * factor   # size of left/bottom margin
     trdim = 0.2 * factor   # size of top/right margin
-    whspace = 0.05         # w/hspace size
     plotdim = factor * ndim + factor * (ndim - 1.) * whspace
     dim = lbdim + plotdim + trdim
 
@@ -99,7 +98,7 @@ def projection_matrix(D, xyz, labels=None, projection='max_slice',
     # Format the figure.
     lb = lbdim / dim
     tr = (lbdim + plotdim) / dim
-    fig.subplots_adjust(left=lb, bottom=lb, right=tr, top=tr,
+    fig.subplots_adjust(left=lb, bottom=lb, right=0.98*tr, top=tr,
                         wspace=whspace, hspace=whspace)
     for i in range(ndim):
         projection_1D(