ip: ip_ra_control() rcu fix

commit 66018506e15b (ip: Router Alert RCU conversion) introduced RCU
lookups to ip_call_ra_chain(). It missed proper deinit phase :
When ip_ra_control() deletes an ip_ra_chain, it should make sure
ip_call_ra_chain() users can not start to use socket during the rcu
grace period. It should also delay the sock_put() after the grace
period, or we risk a premature socket freeing and corruptions, as
raw sockets are not rcu protected yet.

This delay avoids using expensive atomic_inc_not_zero() in
ip_call_ra_chain().

Signed-off-by: Eric Dumazet <eric.dumazet@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/include/net/ip.h b/include/net/ip.h
index 9982c97..d52f011 100644
--- a/include/net/ip.h
+++ b/include/net/ip.h
@@ -61,7 +61,10 @@
 struct ip_ra_chain {
 	struct ip_ra_chain	*next;
 	struct sock		*sk;
-	void			(*destructor)(struct sock *);
+	union {
+		void			(*destructor)(struct sock *);
+		struct sock		*saved_sk;
+	};
 	struct rcu_head		rcu;
 };
 
diff --git a/net/ipv4/ip_sockglue.c b/net/ipv4/ip_sockglue.c
index 08b9519..47fff52 100644
--- a/net/ipv4/ip_sockglue.c
+++ b/net/ipv4/ip_sockglue.c
@@ -241,9 +241,13 @@
 struct ip_ra_chain *ip_ra_chain;
 static DEFINE_SPINLOCK(ip_ra_lock);
 
-static void ip_ra_free_rcu(struct rcu_head *head)
+
+static void ip_ra_destroy_rcu(struct rcu_head *head)
 {
-	kfree(container_of(head, struct ip_ra_chain, rcu));
+	struct ip_ra_chain *ra = container_of(head, struct ip_ra_chain, rcu);
+
+	sock_put(ra->saved_sk);
+	kfree(ra);
 }
 
 int ip_ra_control(struct sock *sk, unsigned char on,
@@ -264,13 +268,20 @@
 				kfree(new_ra);
 				return -EADDRINUSE;
 			}
+			/* dont let ip_call_ra_chain() use sk again */
+			ra->sk = NULL;
 			rcu_assign_pointer(*rap, ra->next);
 			spin_unlock_bh(&ip_ra_lock);
 
 			if (ra->destructor)
 				ra->destructor(sk);
-			sock_put(sk);
-			call_rcu(&ra->rcu, ip_ra_free_rcu);
+			/*
+			 * Delay sock_put(sk) and kfree(ra) after one rcu grace
+			 * period. This guarantee ip_call_ra_chain() dont need
+			 * to mess with socket refcounts.
+			 */
+			ra->saved_sk = sk;
+			call_rcu(&ra->rcu, ip_ra_destroy_rcu);
 			return 0;
 		}
 	}