RDMA/addr: Use client registration to fix module unload race

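Require callers to register as a client of the ib_addr module before
resolving addresses.  Each outstanding request holds a reference on its
client, and unregistering waits until those references are dropped, so
the caller's module cannot be unloaded while one of its callbacks is
still in progress.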

Signed-off-by: Sean Hefty <sean.hefty@intel.com>
Signed-off-by: Roland Dreier <rolandd@cisco.com>
diff --git a/drivers/infiniband/core/addr.c b/drivers/infiniband/core/addr.c
index 60d3fbd..e11187e 100644
--- a/drivers/infiniband/core/addr.c
+++ b/drivers/infiniband/core/addr.c
@@ -47,6 +47,7 @@
 	struct sockaddr src_addr;
 	struct sockaddr dst_addr;
 	struct rdma_dev_addr *addr;
+	struct rdma_addr_client *client;
 	void *context;
 	void (*callback)(int status, struct sockaddr *src_addr,
 			 struct rdma_dev_addr *addr, void *context);
@@ -61,6 +62,26 @@
 static DECLARE_WORK(work, process_req, NULL);
 static struct workqueue_struct *addr_wq;
 
+void rdma_addr_register_client(struct rdma_addr_client *client)
+{
+	atomic_set(&client->refcount, 1);
+	init_completion(&client->comp);
+}
+EXPORT_SYMBOL(rdma_addr_register_client);
+
+static inline void put_client(struct rdma_addr_client *client)
+{
+	if (atomic_dec_and_test(&client->refcount))
+		complete(&client->comp);
+}
+
+void rdma_addr_unregister_client(struct rdma_addr_client *client)
+{
+	put_client(client);
+	wait_for_completion(&client->comp);
+}
+EXPORT_SYMBOL(rdma_addr_unregister_client);
+
 int rdma_copy_addr(struct rdma_dev_addr *dev_addr, struct net_device *dev,
 		     const unsigned char *dst_dev_addr)
 {
@@ -229,6 +250,7 @@
 		list_del(&req->list);
 		req->callback(req->status, &req->src_addr, req->addr,
 			      req->context);
+		put_client(req->client);
 		kfree(req);
 	}
 }
@@ -264,7 +286,8 @@
 	return ret;
 }
 
-int rdma_resolve_ip(struct sockaddr *src_addr, struct sockaddr *dst_addr,
+int rdma_resolve_ip(struct rdma_addr_client *client,
+		    struct sockaddr *src_addr, struct sockaddr *dst_addr,
 		    struct rdma_dev_addr *addr, int timeout_ms,
 		    void (*callback)(int status, struct sockaddr *src_addr,
 				     struct rdma_dev_addr *addr, void *context),
@@ -285,6 +308,8 @@
 	req->addr = addr;
 	req->callback = callback;
 	req->context = context;
+	req->client = client;
+	atomic_inc(&client->refcount);
 
 	src_in = (struct sockaddr_in *) &req->src_addr;
 	dst_in = (struct sockaddr_in *) &req->dst_addr;
@@ -305,6 +330,7 @@
 		break;
 	default:
 		ret = req->status;
+		atomic_dec(&client->refcount);
 		kfree(req);
 		break;
 	}