diff --git a/net/ipv4/fib_trie.c b/net/ipv4/fib_trie.c
index d3dbb48..2b72207 100644
--- a/net/ipv4/fib_trie.c
+++ b/net/ipv4/fib_trie.c
@@ -87,24 +87,38 @@
 
 typedef unsigned int t_key;
 
-#define T_TNODE 0
-#define T_LEAF  1
-#define NODE_TYPE_MASK	0x1UL
-#define NODE_TYPE(node) ((node)->parent & NODE_TYPE_MASK)
+#define IS_TNODE(n) ((n)->bits)		/* non-zero bits: internal node */
+#define IS_LEAF(n) (!(n)->bits)		/* bits == 0: leaf */
 
-#define IS_TNODE(n) (!(n->parent & T_LEAF))
-#define IS_LEAF(n) (n->parent & T_LEAF)
+struct tnode {
+	t_key key;			/* tnode_new() masks this to the top pos bits */
+	unsigned char bits;		/* 2log(KEYLENGTH) bits needed */
+	unsigned char pos;		/* 2log(KEYLENGTH) bits needed */
+	struct tnode __rcu *parent;	/* NULL when the node has no parent */
+	union {
+		struct rcu_head rcu;
+		struct tnode *tnode_free;
+	};
+	unsigned int full_children;	/* KEYLENGTH bits needed */
+	unsigned int empty_children;	/* KEYLENGTH bits needed */
+	struct rt_trie_node __rcu *child[0];	/* 1 << bits slots */
+};
 
 struct rt_trie_node {
-	unsigned long parent;
 	t_key key;
+	unsigned char bits;		/* 0 for a leaf, non-zero for a tnode */
+	unsigned char pos;
+	struct tnode __rcu *parent;
+	struct rcu_head rcu;
 };
 
 struct leaf {
-	unsigned long parent;
 	t_key key;
-	struct hlist_head list;
+	unsigned char bits;		/* leaf_new() sets this to 0 (see IS_LEAF) */
+	unsigned char pos;		/* leaf_new() sets this to KEYLENGTH */
+	struct tnode __rcu *parent;
 	struct rcu_head rcu;
+	struct hlist_head list;
 };
 
 struct leaf_info {
@@ -115,20 +129,6 @@
 	struct rcu_head rcu;
 };
 
-struct tnode {
-	unsigned long parent;
-	t_key key;
-	unsigned char pos;		/* 2log(KEYLENGTH) bits needed */
-	unsigned char bits;		/* 2log(KEYLENGTH) bits needed */
-	unsigned int full_children;	/* KEYLENGTH bits needed */
-	unsigned int empty_children;	/* KEYLENGTH bits needed */
-	union {
-		struct rcu_head rcu;
-		struct tnode *tnode_free;
-	};
-	struct rt_trie_node __rcu *child[0];
-};
-
 #ifdef CONFIG_IP_FIB_TRIE_STATS
 struct trie_use_stats {
 	unsigned int gets;
@@ -176,38 +176,27 @@
 static struct kmem_cache *fn_alias_kmem __read_mostly;
 static struct kmem_cache *trie_leaf_kmem __read_mostly;
 
-/*
- * caller must hold RTNL
- */
-static inline struct tnode *node_parent(const struct rt_trie_node *node)
-{
-	unsigned long parent;
+/* parent tnode of n, or NULL if it has none; caller must hold RTNL */
+#define node_parent(n) rtnl_dereference((n)->parent)
 
-	parent = rcu_dereference_index_check(node->parent, lockdep_rtnl_is_held());
+/* parent tnode of n, or NULL; caller must hold RCU read lock or RTNL */
+#define node_parent_rcu(n) rcu_dereference_rtnl((n)->parent)
 
-	return (struct tnode *)(parent & ~NODE_TYPE_MASK);
-}
-
-/*
- * caller must hold RCU read lock or RTNL
- */
-static inline struct tnode *node_parent_rcu(const struct rt_trie_node *node)
-{
-	unsigned long parent;
-
-	parent = rcu_dereference_index_check(node->parent, rcu_read_lock_held() ||
-							   lockdep_rtnl_is_held());
-
-	return (struct tnode *)(parent & ~NODE_TYPE_MASK);
-}
-
-/* Same as rcu_assign_pointer
- * but that macro() assumes that value is a pointer.
- */
+/* NULL-tolerant wrapper for rcu_assign_pointer(); no-op when node is NULL */
 static inline void node_set_parent(struct rt_trie_node *node, struct tnode *ptr)
 {
-	smp_wmb();
-	node->parent = (unsigned long)ptr | NODE_TYPE(node);
+	if (node)
+		rcu_assign_pointer(node->parent, ptr);
+}
+
+#define NODE_INIT_PARENT(n, p) RCU_INIT_POINTER((n)->parent, p)
+
+/* Number of child slots in tn: 1 << bits for a tnode.  Bit 0 of the result
+ * is masked off so that a leaf (bits == 0) reports 0 children instead of 1.
+ */
+static inline int tnode_child_length(const struct tnode *tn)
+{
+	return (1ul << tn->bits) & ~(1ul);
 }
 
 /*
@@ -215,7 +204,7 @@
  */
 static inline struct rt_trie_node *tnode_get_child(const struct tnode *tn, unsigned int i)
 {
-	BUG_ON(i >= 1U << tn->bits);
+	BUG_ON(i >= tnode_child_length(tn));
 
 	return rtnl_dereference(tn->child[i]);
 }
@@ -225,16 +214,11 @@
  */
 static inline struct rt_trie_node *tnode_get_child_rcu(const struct tnode *tn, unsigned int i)
 {
-	BUG_ON(i >= 1U << tn->bits);
+	BUG_ON(i >= tnode_child_length(tn));
 
 	return rcu_dereference_rtnl(tn->child[i]);
 }
 
-static inline int tnode_child_length(const struct tnode *tn)
-{
-	return 1 << tn->bits;
-}
-
 static inline t_key mask_pfx(t_key k, unsigned int l)
 {
 	return (l == 0) ? 0 : k >> (KEYLENGTH-l) << (KEYLENGTH-l);
@@ -336,11 +320,6 @@
 
 */
 
-static inline void check_tnode(const struct tnode *tn)
-{
-	WARN_ON(tn && tn->pos+tn->bits > 32);
-}
-
 static const int halve_threshold = 25;
 static const int inflate_threshold = 50;
 static const int halve_threshold_root = 15;
@@ -426,11 +405,20 @@
 	}
 }
 
-static struct leaf *leaf_new(void)
+static struct leaf *leaf_new(t_key key)
 {
 	struct leaf *l = kmem_cache_alloc(trie_leaf_kmem, GFP_KERNEL);
 	if (l) {
-		l->parent = T_LEAF;
+		l->parent = NULL;
+		/* set key and pos to reflect full key value
+		 * any trailing zeros in the key should be ignored
+		 * as the nodes are searched
+		 */
+		l->key = key;
+		l->pos = KEYLENGTH;
+		/* set bits to 0 indicating we are not a tnode */
+		l->bits = 0;
+
 		INIT_HLIST_HEAD(&l->list);
 	}
 	return l;
@@ -451,12 +439,16 @@
 {
 	size_t sz = sizeof(struct tnode) + (sizeof(struct rt_trie_node *) << bits);
 	struct tnode *tn = tnode_alloc(sz);
+	unsigned int shift = pos + bits;
+
+	/* bits must be non-zero and pos + bits must not exceed KEYLENGTH */
+	BUG_ON(!bits || (shift > KEYLENGTH));
 
 	if (tn) {
-		tn->parent = T_TNODE;
+		tn->parent = NULL;
 		tn->pos = pos;
 		tn->bits = bits;
-		tn->key = key;
+		tn->key = mask_pfx(key, pos);
 		tn->full_children = 0;
 		tn->empty_children = 1<<bits;
 	}
@@ -473,10 +465,7 @@
 
 static inline int tnode_full(const struct tnode *tn, const struct rt_trie_node *n)
 {
-	if (n == NULL || IS_LEAF(n))
-		return 0;
-
-	return ((struct tnode *) n)->pos == tn->pos + tn->bits;
+	return n && IS_TNODE(n) && (n->pos == (tn->pos + tn->bits));
 }
 
 static inline void put_child(struct tnode *tn, int i,
@@ -514,8 +503,7 @@
 	else if (!wasfull && isfull)
 		tn->full_children++;
 
-	if (n)
-		node_set_parent(n, tn);
+	node_set_parent(n, tn);
 
 	rcu_assign_pointer(tn->child[i], n);
 }
@@ -523,7 +511,7 @@
 #define MAX_WORK 10
 static struct rt_trie_node *resize(struct trie *t, struct tnode *tn)
 {
-	int i;
+	struct rt_trie_node *n = NULL;
 	struct tnode *old_tn;
 	int inflate_threshold_use;
 	int halve_threshold_use;
@@ -536,12 +524,11 @@
 		 tn, inflate_threshold, halve_threshold);
 
 	/* No children */
-	if (tn->empty_children == tnode_child_length(tn)) {
-		tnode_free_safe(tn);
-		return NULL;
-	}
+	if (tn->empty_children > (tnode_child_length(tn) - 1))
+		goto no_children;
+
 	/* One child */
-	if (tn->empty_children == tnode_child_length(tn) - 1)
+	if (tn->empty_children == (tnode_child_length(tn) - 1))
 		goto one_child;
 	/*
 	 * Double as long as the resulting node has a number of
@@ -607,11 +594,9 @@
 	 *
 	 */
 
-	check_tnode(tn);
-
 	/* Keep root node larger  */
 
-	if (!node_parent((struct rt_trie_node *)tn)) {
+	if (!node_parent(tn)) {
 		inflate_threshold_use = inflate_threshold_root;
 		halve_threshold_use = halve_threshold_root;
 	} else {
@@ -637,8 +622,6 @@
 		}
 	}
 
-	check_tnode(tn);
-
 	/* Return if at least one inflate is run */
 	if (max_work != MAX_WORK)
 		return (struct rt_trie_node *) tn;
@@ -666,21 +649,16 @@
 
 
 	/* Only one child remains */
-	if (tn->empty_children == tnode_child_length(tn) - 1) {
+	if (tn->empty_children == (tnode_child_length(tn) - 1)) {
+		unsigned long i;
 one_child:
-		for (i = 0; i < tnode_child_length(tn); i++) {
-			struct rt_trie_node *n;
-
-			n = rtnl_dereference(tn->child[i]);
-			if (!n)
-				continue;
-
-			/* compress one level */
-
-			node_set_parent(n, NULL);
-			tnode_free_safe(tn);
-			return n;
-		}
+		for (i = tnode_child_length(tn); !n && i;)
+			n = tnode_get_child(tn, --i);
+no_children:
+		/* compress one level */
+		node_set_parent(n, NULL);
+		tnode_free_safe(tn);
+		return n;
 	}
 	return (struct rt_trie_node *) tn;
 }
@@ -760,8 +738,7 @@
 
 		/* A leaf or an internal node with skipped bits */
 
-		if (IS_LEAF(node) || ((struct tnode *) node)->pos >
-		   tn->pos + tn->bits - 1) {
+		if (IS_LEAF(node) || (node->pos > (tn->pos + tn->bits - 1))) {
 			put_child(tn,
 				tkey_extract_bits(node->key, oldtnode->pos, oldtnode->bits + 1),
 				node);
@@ -958,11 +935,9 @@
 	pos = 0;
 	n = rcu_dereference_rtnl(t->trie);
 
-	while (n != NULL &&  NODE_TYPE(n) == T_TNODE) {
+	while (n && IS_TNODE(n)) {
 		tn = (struct tnode *) n;
 
-		check_tnode(tn);
-
 		if (tkey_sub_equals(tn->key, pos, tn->pos-pos, key)) {
 			pos = tn->pos + tn->bits;
 			n = tnode_get_child_rcu(tn,
@@ -988,7 +963,7 @@
 
 	key = tn->key;
 
-	while (tn != NULL && (tp = node_parent((struct rt_trie_node *)tn)) != NULL) {
+	while (tn != NULL && (tp = node_parent(tn)) != NULL) {
 		cindex = tkey_extract_bits(key, tp->pos, tp->bits);
 		wasfull = tnode_full(tp, tnode_get_child(tp, cindex));
 		tn = (struct tnode *)resize(t, tn);
@@ -996,7 +971,7 @@
 		tnode_put_child_reorg(tp, cindex,
 				      (struct rt_trie_node *)tn, wasfull);
 
-		tp = node_parent((struct rt_trie_node *) tn);
+		tp = node_parent(tn);
 		if (!tp)
 			rcu_assign_pointer(t->trie, (struct rt_trie_node *)tn);
 
@@ -1048,11 +1023,9 @@
 	 * If it doesn't, we need to replace it with a T_TNODE.
 	 */
 
-	while (n != NULL &&  NODE_TYPE(n) == T_TNODE) {
+	while (n && IS_TNODE(n)) {
 		tn = (struct tnode *) n;
 
-		check_tnode(tn);
-
 		if (tkey_sub_equals(tn->key, pos, tn->pos-pos, key)) {
 			tp = tn;
 			pos = tn->pos + tn->bits;
@@ -1087,12 +1060,11 @@
 		insert_leaf_info(&l->list, li);
 		goto done;
 	}
-	l = leaf_new();
+	l = leaf_new(key);
 
 	if (!l)
 		return NULL;
 
-	l->key = key;
 	li = leaf_info_new(plen);
 
 	if (!li) {
@@ -1569,7 +1541,7 @@
 		if (chopped_off <= pn->bits) {
 			cindex &= ~(1 << (chopped_off-1));
 		} else {
-			struct tnode *parent = node_parent_rcu((struct rt_trie_node *) pn);
+			struct tnode *parent = node_parent_rcu(pn);
 			if (!parent)
 				goto failed;
 
@@ -1597,7 +1569,7 @@
  */
 static void trie_leaf_remove(struct trie *t, struct leaf *l)
 {
-	struct tnode *tp = node_parent((struct rt_trie_node *) l);
+	struct tnode *tp = node_parent(l);
 
 	pr_debug("entering trie_leaf_remove(%p)\n", l);
 
@@ -2374,7 +2346,7 @@
 
 	if (IS_TNODE(n)) {
 		struct tnode *tn = (struct tnode *) n;
-		__be32 prf = htonl(mask_pfx(tn->key, tn->pos));
+		__be32 prf = htonl(tn->key);
 
 		seq_indent(seq, iter->depth-1);
 		seq_printf(seq, "  +-- %pI4/%d %d %d %d\n",
