diff --git a/python/MS_helper_functions.py b/python/MS_helper_functions.py
index aff1efc2f5af59b59cce52a09a7823bd4f70f373..24ef6f2c952562fc0b24217c189ad4d15b7069a7 100644
--- a/python/MS_helper_functions.py
+++ b/python/MS_helper_functions.py
@@ -883,9 +883,15 @@ class cl_MS():
             kwargs.pop("dim")
 
         self.is2D1D = False
-        if self.base_space.mesh.dim != self.dim:
+        if "is2D1D" in kwargs.keys():
+            self.is2D1D = kwargs["is2D1D"]
+            kwargs.pop("is2D1D")
+
+        elif (self.base_space.mesh.dim == 2 and self.dim == 3):
             self.is2D1D = True
 
+        
+
         self.zeroCoupling = {}
         if "zeroCoupling" in kwargs.keys():
             self.zeroCoupling = kwargs["zeroCoupling"]
@@ -898,6 +904,7 @@ class cl_MS():
         if len(kwargs.keys()) != 0:
             print("cl_MS unknown kewords:")
             print(kwargs.keys())
+            input()
 
         self.coupling_map = None
         self.sol_pack = []
@@ -1200,8 +1207,8 @@ class cl_MS():
                     # boundary terms
                     ret += InnerProduct(self.coupling_map[phi_i][phi_j] * terms_u[i][0] * specialcf.normal(self.dim), terms_v[j][0] * specialcf.normal(self.dim))
                 else:
-                    if (self.is2D1D and hasattr(self.coupling_map[phi_i][phi_j], "dims") and len(self.coupling_map[phi_i][phi_j].dims)> 0) :#or \
-                        #(hasattr(self.coupling_map[phi_i][phi_j], "dims") and hasattr(terms_u[i][0], "dims") and self.coupling_map[phi_i][phi_j].dims[0] != terms_u[i][0].dim):
+                    if (self.is2D1D and hasattr(self.coupling_map[phi_i][phi_j], "dims") and len(self.coupling_map[phi_i][phi_j].dims)> 0) or \
+                        ((not self.is2D1D) and hasattr(self.coupling_map[phi_i][phi_j], "dims") and hasattr(terms_u[i][0], "dims") and self.coupling_map[phi_i][phi_j].dims[0] != terms_u[i][0].dim):
                         # 2D1D or 3x2 
                         tmp = self.coupling_map[phi_i][phi_j]
                         if len(self.coupling_map[phi_i][phi_j].dims) == 2: