librdmacm: fix race in ucma_init()
[mirror/winof/.git] / ulp / librdmacm / src / cma.cpp
index 934e2f2..7d31c30 100644 (file)
 #include "cma.h"\r
 #include "..\..\..\etc\user\comp_channel.cpp"\r
 \r
-static struct ibv_windata windata;\r
+static struct ibvw_windata windata;\r
 \r
 enum cma_state\r
 {\r
        cma_idle,\r
        cma_listening,\r
        cma_get_request,\r
+       cma_addr_bind,\r
        cma_addr_resolve,\r
        cma_route_resolve,\r
        cma_passive_connect,\r
@@ -55,7 +56,8 @@ enum cma_state
        cma_connected,\r
        cma_active_disconnect,\r
        cma_passive_disconnect,\r
-       cma_disconnected\r
+       cma_disconnected,\r
+       cma_destroying\r
 };\r
 \r
 #define CMA_DEFAULT_BACKLOG            16\r
@@ -67,6 +69,7 @@ struct cma_id_private
        struct cma_device                       *cma_dev;\r
        int                                                     backlog;\r
        int                                                     index;\r
+       volatile LONG                           refcnt;\r
        struct rdma_cm_id                       **req_list;\r
 };\r
 \r
@@ -88,63 +91,49 @@ struct cma_event {
 static struct cma_device *cma_dev_array;\r
 static int cma_dev_cnt;\r
 \r
-static void ucma_cleanup(void)\r
-{\r
-       if (cma_dev_cnt > 0) {\r
-               while (cma_dev_cnt > 0) {\r
-                       ibv_close_device(cma_dev_array[--cma_dev_cnt].verbs);\r
-               }\r
-               delete cma_dev_array;\r
-               cma_dev_cnt = 0;\r
-       }\r
-       if (windata.prov != NULL) {\r
-               ibv_release_windata(&windata, IBV_WINDATA_VERSION);\r
-               windata.prov = NULL;\r
-       }\r
-}\r
-\r
 static int ucma_init(void)\r
 {\r
        struct ibv_device **dev_list = NULL;\r
        struct cma_device *cma_dev;\r
        struct ibv_device_attr attr;\r
-       int i, ret;\r
+       int i, ret, dev_cnt;\r
 \r
        EnterCriticalSection(&lock);\r
-       if (cma_dev_cnt > 0) {\r
+       if (cma_dev_cnt) {\r
                goto out;\r
        }\r
 \r
-       ret = ibv_get_windata(&windata, IBV_WINDATA_VERSION);\r
+       ret = ibvw_get_windata(&windata, IBVW_WINDATA_VERSION);\r
        if (ret) {\r
-               goto err;\r
+               goto err1;\r
        }\r
 \r
-       dev_list = ibv_get_device_list(&cma_dev_cnt);\r
+       dev_list = ibv_get_device_list(&dev_cnt);\r
        if (dev_list == NULL) {\r
                ret = -1;\r
-               goto err;\r
+               goto err2;\r
        }\r
 \r
-       cma_dev_array = new struct cma_device[cma_dev_cnt];\r
+       cma_dev_array = new struct cma_device[dev_cnt];\r
        if (cma_dev_array == NULL) {\r
                ret = -1;\r
-               goto err;\r
+               goto err3;\r
        }\r
 \r
-       for (i = 0; dev_list[i]; ++i) {\r
+       for (i = 0; dev_list[i];) {\r
                cma_dev = &cma_dev_array[i];\r
 \r
                cma_dev->guid = ibv_get_device_guid(dev_list[i]);\r
                cma_dev->verbs = ibv_open_device(dev_list[i]);\r
                if (cma_dev->verbs == NULL) {\r
                        ret = -1;\r
-                       goto err;\r
+                       goto err4;\r
                }\r
 \r
+               ++i;\r
                ret = ibv_query_device(cma_dev->verbs, &attr);\r
                if (ret) {\r
-                       goto err;\r
+                       goto err4;\r
                }\r
 \r
                cma_dev->port_cnt = attr.phys_port_cnt;\r
@@ -152,16 +141,22 @@ static int ucma_init(void)
                cma_dev->max_responder_resources = (uint8_t) attr.max_qp_rd_atom;\r
        }\r
        ibv_free_device_list(dev_list);\r
+       cma_dev_cnt = dev_cnt;\r
 out:\r
        LeaveCriticalSection(&lock);\r
        return 0;\r
 \r
-err:\r
-       ucma_cleanup();\r
-       LeaveCriticalSection(&lock);\r
-       if (dev_list) {\r
-               ibv_free_device_list(dev_list);\r
+err4:\r
+       while (i) {\r
+               ibv_close_device(cma_dev_array[--i].verbs);\r
        }\r
+       delete cma_dev_array;\r
+err3:\r
+       ibv_free_device_list(dev_list);\r
+err2:\r
+       ibvw_release_windata(&windata, IBVW_WINDATA_VERSION);\r
+err1:\r
+       LeaveCriticalSection(&lock);\r
        return ret;\r
 }\r
 \r
@@ -241,6 +236,7 @@ int rdma_create_id(struct rdma_event_channel *channel,
        }\r
 \r
        RtlZeroMemory(id_priv, sizeof(struct cma_id_private));\r
+       id_priv->refcnt = 1;\r
        id_priv->id.context = context;\r
        id_priv->id.channel = channel;\r
        id_priv->id.ps = ps;\r
@@ -267,6 +263,7 @@ static void ucma_destroy_listen(struct cma_id_private *id_priv)
 {\r
        while (--id_priv->backlog >= 0) {\r
                if (id_priv->req_list[id_priv->backlog] != NULL) {\r
+                       InterlockedDecrement(&id_priv->refcnt);\r
                        rdma_destroy_id(id_priv->req_list[id_priv->backlog]);\r
                }\r
        }\r
@@ -280,13 +277,20 @@ int rdma_destroy_id(struct rdma_cm_id *id)
        struct cma_id_private *id_priv;\r
 \r
        id_priv = CONTAINING_RECORD(id, struct cma_id_private, id);\r
+\r
+       EnterCriticalSection(&lock);\r
+       id_priv->state = cma_destroying;\r
+       LeaveCriticalSection(&lock);\r
+\r
        if (id->ps == RDMA_PS_TCP) {\r
                id->ep.connect->CancelOverlappedRequests();\r
        } else {\r
                id->ep.datagram->CancelOverlappedRequests();\r
        }\r
 \r
-       CompChannelRemoveEntry(&id->channel->channel, &id->comp_entry);\r
+       if (CompEntryCancel(&id->comp_entry) != NULL) {\r
+               InterlockedDecrement(&id_priv->refcnt);\r
+       }\r
 \r
        if (id_priv->backlog > 0) {\r
                ucma_destroy_listen(id_priv);\r
@@ -298,6 +302,10 @@ int rdma_destroy_id(struct rdma_cm_id *id)
                id_priv->id.ep.datagram->Release();\r
        }\r
 \r
+       InterlockedDecrement(&id_priv->refcnt);\r
+       while (id_priv->refcnt) {\r
+               Sleep(0);\r
+       }\r
        delete id_priv;\r
        return 0;\r
 }\r
@@ -407,6 +415,7 @@ static int ucma_query_datagram(struct rdma_cm_id *id, struct rdma_ud_param *para
 __declspec(dllexport)\r
 int rdma_bind_addr(struct rdma_cm_id *id, struct sockaddr *addr)\r
 {\r
+       struct cma_id_private *id_priv;\r
        HRESULT hr;\r
 \r
        if (id->ps == RDMA_PS_TCP) {\r
@@ -421,6 +430,10 @@ int rdma_bind_addr(struct rdma_cm_id *id, struct sockaddr *addr)
                }\r
        }\r
 \r
+       if (SUCCEEDED(hr)) {\r
+               id_priv = CONTAINING_RECORD(id, struct cma_id_private, id);\r
+               id_priv->state = cma_addr_bind;\r
+       }\r
        return hr;\r
 }\r
 \r
@@ -434,34 +447,37 @@ int rdma_resolve_addr(struct rdma_cm_id *id, struct sockaddr *src_addr,
        DWORD size;\r
        HRESULT hr;\r
 \r
-       if (src_addr == NULL) {\r
-               if (id->ps == RDMA_PS_TCP) {\r
-                       s = socket(dst_addr->sa_family, SOCK_STREAM, IPPROTO_TCP);\r
-               } else {\r
-                       s = socket(dst_addr->sa_family, SOCK_DGRAM, IPPROTO_UDP);\r
-               }\r
-               if (s == INVALID_SOCKET) {\r
-                       return WSAGetLastError();\r
+       id_priv = CONTAINING_RECORD(id, struct cma_id_private, id);\r
+       if (id_priv->state == cma_idle) {\r
+               if (src_addr == NULL) {\r
+                       if (id->ps == RDMA_PS_TCP) {\r
+                               s = socket(dst_addr->sa_family, SOCK_STREAM, IPPROTO_TCP);\r
+                       } else {\r
+                               s = socket(dst_addr->sa_family, SOCK_DGRAM, IPPROTO_UDP);\r
+                       }\r
+                       if (s == INVALID_SOCKET) {\r
+                               return WSAGetLastError();\r
+                       }\r
+\r
+                       hr = WSAIoctl(s, SIO_ROUTING_INTERFACE_QUERY, dst_addr, ucma_addrlen(dst_addr),\r
+                                                 &addr, sizeof addr, &size, NULL, NULL);\r
+                       closesocket(s);\r
+                       if (FAILED(hr)) {\r
+                               return WSAGetLastError();\r
+                       }\r
+                       src_addr = &addr.Sa;\r
                }\r
 \r
-               hr = WSAIoctl(s, SIO_ROUTING_INTERFACE_QUERY, dst_addr, ucma_addrlen(dst_addr),\r
-                                         &addr, sizeof addr, &size, NULL, NULL);\r
-               closesocket(s);\r
+               hr = rdma_bind_addr(id, src_addr);\r
                if (FAILED(hr)) {\r
-                       return WSAGetLastError();\r
+                       return hr;\r
                }\r
-               src_addr = &addr.Sa;\r
-       }\r
-\r
-       hr = rdma_bind_addr(id, src_addr);\r
-       if (FAILED(hr)) {\r
-               return hr;\r
        }\r
 \r
        RtlCopyMemory(&id->route.addr.dst_addr, dst_addr, ucma_addrlen(dst_addr));\r
-       id_priv = CONTAINING_RECORD(id, struct cma_id_private, id);\r
        id_priv->state = cma_addr_resolve;\r
 \r
+       id_priv->refcnt++;\r
        CompEntryPost(&id->comp_entry);\r
        return 0;\r
 }\r
@@ -473,7 +489,16 @@ int rdma_resolve_route(struct rdma_cm_id *id, int timeout_ms)
        IBAT_PATH_BLOB path;\r
        HRESULT hr;\r
 \r
-       hr = IBAT::Resolve(&id->route.addr.src_addr, &id->route.addr.dst_addr, &path);\r
+       do {\r
+               hr = IBAT::Resolve(&id->route.addr.src_addr, &id->route.addr.dst_addr, &path);\r
+               if (hr != E_PENDING) {\r
+                       break;\r
+               }\r
+               timeout_ms -= 10;\r
+               if (timeout_ms > 0)\r
+                       Sleep(10);\r
+       } while (timeout_ms > 0);\r
+\r
        if (FAILED(hr)) {\r
                return hr;\r
        }\r
@@ -488,6 +513,7 @@ int rdma_resolve_route(struct rdma_cm_id *id, int timeout_ms)
        id_priv = CONTAINING_RECORD(id, struct cma_id_private, id);\r
        id_priv->state = cma_route_resolve;\r
 \r
+       id_priv->refcnt++;\r
        CompEntryPost(&id->comp_entry);\r
        return 0;\r
 }\r
@@ -613,9 +639,13 @@ int rdma_connect(struct rdma_cm_id *id, struct rdma_conn_param *conn_param)
        }\r
 \r
        id_priv->state = cma_active_connect;\r
+       id_priv->refcnt++;\r
+       id->comp_entry.Busy = 1;\r
        hr = id->ep.connect->Connect(id->qp->conn_handle, &id->route.addr.dst_addr,\r
                                                                 &attr, &id->comp_entry.Overlap);\r
        if (FAILED(hr) && hr != WV_IO_PENDING) {\r
+               id_priv->refcnt--;\r
+               id->comp_entry.Busy = 0;\r
                id_priv->state = cma_route_resolve;\r
                return hr;\r
        }\r
@@ -625,19 +655,28 @@ int rdma_connect(struct rdma_cm_id *id, struct rdma_conn_param *conn_param)
 \r
 static int ucma_get_request(struct cma_id_private *listen, int index)\r
 {\r
-       struct cma_id_private *id_priv;\r
+       struct cma_id_private *id_priv = NULL;\r
        HRESULT hr;\r
 \r
+       EnterCriticalSection(&lock);\r
+       if (listen->state != cma_listening) {\r
+               hr = WV_INVALID_PARAMETER;\r
+               goto err1;\r
+       }\r
+\r
+       InterlockedIncrement(&listen->refcnt);\r
        hr = rdma_create_id(listen->id.channel, &listen->req_list[index],\r
                                                listen, listen->id.ps);\r
        if (FAILED(hr)) {\r
-               return hr;\r
+               goto err2;\r
        }\r
 \r
        id_priv = CONTAINING_RECORD(listen->req_list[index], struct cma_id_private, id);\r
        id_priv->index = index;\r
        id_priv->state = cma_get_request;\r
 \r
+       id_priv->refcnt++;\r
+       id_priv->id.comp_entry.Busy = 1;\r
        if (listen->id.ps == RDMA_PS_TCP) {\r
                hr = listen->id.ep.connect->GetRequest(id_priv->id.ep.connect,\r
                                                                                           &id_priv->id.comp_entry.Overlap);\r
@@ -646,10 +685,22 @@ static int ucma_get_request(struct cma_id_private *listen, int index)
                                                                                                &id_priv->id.comp_entry.Overlap);\r
        }\r
        if (FAILED(hr) && hr != WV_IO_PENDING) {\r
-               return hr;\r
+               id_priv->id.comp_entry.Busy = 0;\r
+               id_priv->refcnt--;\r
+               goto err2;\r
        }\r
+       LeaveCriticalSection(&lock);\r
 \r
        return 0;\r
+\r
+err2:\r
+       InterlockedDecrement(&listen->refcnt);\r
+err1:\r
+       LeaveCriticalSection(&lock);\r
+       if (id_priv != NULL) {\r
+               rdma_destroy_id(&id_priv->id);\r
+       }\r
+       return hr;\r
 }\r
 \r
 __declspec(dllexport)\r
@@ -712,9 +763,13 @@ int rdma_accept(struct rdma_cm_id *id, struct rdma_conn_param *conn_param)
        }\r
 \r
        id_priv->state = cma_accepting;\r
+       id_priv->refcnt++;\r
+       id->comp_entry.Busy = 1;\r
        hr = id->ep.connect->Accept(id->qp->conn_handle, &attr,\r
                                                                &id->comp_entry.Overlap);\r
        if (FAILED(hr) && hr != WV_IO_PENDING) {\r
+               id_priv->refcnt--;\r
+               id->comp_entry.Busy = 0;\r
                id_priv->state = cma_disconnected;\r
                return hr;\r
        }\r
@@ -756,7 +811,7 @@ int rdma_disconnect(struct rdma_cm_id *id)
        } else {\r
                id_priv->state = cma_disconnected;\r
        }\r
-       hr = id->ep.connect->Disconnect();\r
+       hr = id->ep.connect->Disconnect(NULL);\r
        if (FAILED(hr)) {\r
                return hr;\r
        }\r
@@ -768,30 +823,41 @@ __declspec(dllexport)
 int rdma_ack_cm_event(struct rdma_cm_event *event)\r
 {\r
        struct cma_event *evt;\r
+       struct cma_id_private *listen;\r
 \r
        evt = CONTAINING_RECORD(event, struct cma_event, event);\r
+       InterlockedDecrement(&evt->id_priv->refcnt);\r
+       if (evt->event.listen_id) {\r
+               listen = CONTAINING_RECORD(evt->event.listen_id, struct cma_id_private, id);\r
+               InterlockedDecrement(&listen->refcnt);\r
+       }\r
        delete evt;\r
        return 0;\r
 }\r
 \r
 static int ucma_process_conn_req(struct cma_event *event)\r
 {\r
-       struct cma_id_private *listen;\r
+       struct cma_id_private *listen, *id_priv;\r
        struct cma_event_channel *chan;\r
 \r
        listen = (struct cma_id_private *) event->id_priv->id.context;\r
-       ucma_get_request(listen, event->id_priv->index);\r
+       id_priv = event->id_priv;\r
+\r
+       ucma_get_request(listen, id_priv->index);\r
 \r
        if (SUCCEEDED(event->event.status)) {\r
-               event->event.status = ucma_query_connect(&event->id_priv->id,\r
+               event->event.status = ucma_query_connect(&id_priv->id,\r
                                                                                                 &event->event.param.conn);\r
        }\r
 \r
        if (SUCCEEDED(event->event.status)) {\r
                event->event.event = RDMA_CM_EVENT_CONNECT_REQUEST;\r
-               event->id_priv->state = cma_passive_connect;\r
+               id_priv->state = cma_passive_connect;\r
+               event->event.listen_id = &listen->id;\r
        } else {\r
-               rdma_destroy_id(&event->id_priv->id);\r
+               InterlockedDecrement(&listen->refcnt);\r
+               InterlockedDecrement(&id_priv->refcnt);\r
+               rdma_destroy_id(&id_priv->id);\r
        }\r
 \r
        return event->event.status;\r
@@ -811,9 +877,11 @@ static int ucma_process_conn_resp(struct cma_event *event)
        event->id_priv->state = cma_accepting;\r
 \r
        id = &event->id_priv->id;\r
+       id->comp_entry.Busy = 1;\r
        hr = id->ep.connect->Accept(id->qp->conn_handle, &attr,\r
                                                                &id->comp_entry.Overlap);\r
        if (FAILED(hr) && hr != WV_IO_PENDING) {\r
+               id->comp_entry.Busy = 0;\r
                event->event.status = hr;\r
                goto err;\r
        }\r
@@ -841,6 +909,8 @@ static void ucma_process_establish(struct cma_event *event)
                event->event.event = RDMA_CM_EVENT_ESTABLISHED;\r
 \r
                id_priv->state = cma_connected;\r
+               InterlockedIncrement(&id_priv->refcnt);\r
+               id_priv->id.comp_entry.Busy = 1;\r
                id_priv->id.ep.connect->NotifyDisconnect(&id_priv->id.comp_entry.Overlap);\r
        } else {\r
                event->event.event = RDMA_CM_EVENT_CONNECT_ERROR;\r
@@ -850,13 +920,25 @@ static void ucma_process_establish(struct cma_event *event)
 \r
 static int ucma_process_event(struct cma_event *event)\r
 {\r
+       struct cma_id_private *listen, *id_priv;\r
        WV_CONNECT_ATTRIBUTES attr;\r
        HRESULT hr = 0;\r
 \r
-       switch (event->id_priv->state) {\r
+       id_priv = event->id_priv;\r
+\r
+       EnterCriticalSection(&lock);\r
+       switch (id_priv->state) {\r
        case cma_get_request:\r
-               hr = ucma_process_conn_req(event);\r
-               break;\r
+               listen = (struct cma_id_private *) id_priv->id.context;\r
+               if (listen->state != cma_listening) {\r
+                       InterlockedDecrement(&id_priv->refcnt);\r
+                       hr = WV_CANCELLED;\r
+                       break;\r
+               }\r
+\r
+               listen->req_list[id_priv->index] = NULL;\r
+               LeaveCriticalSection(&lock);\r
+               return ucma_process_conn_req(event);\r
        case cma_addr_resolve:\r
                event->event.event = RDMA_CM_EVENT_ADDR_RESOLVED;\r
                break;\r
@@ -871,15 +953,17 @@ static int ucma_process_event(struct cma_event *event)
                break;\r
        case cma_connected:\r
                event->event.event = RDMA_CM_EVENT_DISCONNECTED;\r
-               event->id_priv->state = cma_passive_disconnect;\r
+               id_priv->state = cma_passive_disconnect;\r
                break;\r
        case cma_active_disconnect:\r
                event->event.event = RDMA_CM_EVENT_DISCONNECTED;\r
-               event->id_priv->state = cma_disconnected;\r
+               id_priv->state = cma_disconnected;\r
                break;\r
        default:\r
-               return -1;\r
+               InterlockedDecrement(&id_priv->refcnt);\r
+               hr = WV_CANCELLED;\r
        }\r
+       LeaveCriticalSection(&lock);\r
 \r
        return hr;\r
 }\r
@@ -903,6 +987,7 @@ int rdma_get_cm_event(struct rdma_event_channel *channel,
 \r
                ret = CompChannelPoll(&channel->channel, &entry);\r
                if (ret) {\r
+                       delete evt;\r
                        return ret;\r
                }\r
 \r