ND Provider source.
[mirror/winof/.git] / ulp / nd / user / NdProv.cpp
1 /*\r
2  * Copyright (c) 2008 Microsoft Corporation.  All rights reserved.\r
3  *\r
4  * This software is available to you under the OpenIB.org BSD license\r
5  * below:\r
6  *\r
7  *     Redistribution and use in source and binary forms, with or\r
8  *     without modification, are permitted provided that the following\r
9  *     conditions are met:\r
10  *\r
11  *      - Redistributions of source code must retain the above\r
12  *        copyright notice, this list of conditions and the following\r
13  *        disclaimer.\r
14  *\r
15  *      - Redistributions in binary form must reproduce the above\r
16  *        copyright notice, this list of conditions and the following\r
17  *        disclaimer in the documentation and/or other materials\r
18  *        provided with the distribution.\r
19  *\r
20  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,\r
21  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF\r
22  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND\r
23  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS\r
24  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN\r
25  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN\r
26  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\r
27  * SOFTWARE.\r
28  *\r
29  * $Id:$\r
30  */\r
31 \r
32 #include <tchar.h>\r
33 #include <ndspi.h>\r
34 #include <iba/ib_at_ioctl.h>\r
35 #include <complib/cl_types.h>\r
36 #include <complib/cl_ioctl.h>\r
37 #pragma warning( push, 3 )\r
38 #include <unknwn.h>\r
39 #include <assert.h>\r
40 #include <ws2tcpip.h>\r
41 #include <winioctl.h>\r
42 #include <limits.h>\r
43 #include <ws2spi.h>\r
44 #pragma warning( pop )\r
45 #include "ndprov.h"\r
46 #include "ndadapter.h"\r
47 \r
48 #if defined(EVENT_TRACING)\r
49 #ifdef offsetof\r
50 #undef offsetof\r
51 #endif\r
52 #include "NdProv.tmh"\r
53 #endif\r
54 \r
55 #include "nddebug.h"\r
56 \r
57 uint32_t g_nd_dbg_level = TRACE_LEVEL_ERROR;\r
58 /* WPP doesn't want here literals! */\r
59 uint32_t g_nd_dbg_flags = 0x80000001; /* ND_DBG_ERROR | ND_DBG_NDI; */\r
60 \r
61 namespace NetworkDirect\r
62 {\r
63 \r
64     static LONG gnRef = 0;\r
65 \r
66     CProvider::CProvider() :\r
67         m_nRef( 1 )\r
68     {\r
69         InterlockedIncrement( &gnRef );\r
70     }\r
71 \r
72     CProvider::~CProvider()\r
73     {\r
74         InterlockedDecrement( &gnRef );\r
75     }\r
76 \r
77     HRESULT CProvider::QueryInterface(\r
78         const IID &riid,\r
79         void **ppObject )\r
80     {\r
81         if( IsEqualIID( riid, IID_IUnknown ) )\r
82         {\r
83             *ppObject = this;\r
84             return S_OK;\r
85         }\r
86 \r
87         if( IsEqualIID( riid, IID_INDProvider ) )\r
88         {\r
89             *ppObject = this;\r
90             return S_OK;\r
91         }\r
92 \r
93         return E_NOINTERFACE;\r
94     }\r
95 \r
96     ULONG CProvider::AddRef()\r
97     {\r
98         return InterlockedIncrement( &m_nRef );\r
99     }\r
100 \r
101     ULONG CProvider::Release()\r
102     {\r
103         ULONG ref = InterlockedDecrement( &m_nRef );\r
104         if( ref == 0 )\r
105             delete this;\r
106 \r
107         return ref;\r
108     }\r
109 \r
110     HRESULT CProvider::QueryAddressList(\r
111             __out_bcount_part_opt(*pBufferSize, *pBufferSize) SOCKET_ADDRESS_LIST* pAddressList,\r
112             __inout SIZE_T* pBufferSize )\r
113     {\r
114         ND_ENTER( ND_DBG_NDI );\r
115 \r
116         HANDLE hIbatDev = CreateFileW( IBAT_WIN32_NAME,\r
117             MAXIMUM_ALLOWED, 0, NULL,\r
118             OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL );\r
119         if( hIbatDev == INVALID_HANDLE_VALUE )\r
120             return ND_NO_MEMORY;\r
121 \r
122         IOCTL_IBAT_IP_ADDRESSES_IN addrIn;\r
123 \r
124         addrIn.Version = IBAT_IOCTL_VERSION;\r
125         addrIn.PortGuid = 0;\r
126 \r
127         DWORD size = sizeof(IOCTL_IBAT_IP_ADDRESSES_OUT);\r
128         IOCTL_IBAT_IP_ADDRESSES_OUT *pAddrOut;\r
129         do\r
130         {\r
131             pAddrOut = (IOCTL_IBAT_IP_ADDRESSES_OUT*)HeapAlloc(\r
132                 GetProcessHeap(),\r
133                 0,\r
134                 size );\r
135             if( !pAddrOut )\r
136             {\r
137                 //AL_PRINT( TRACE_LEVEL_ERROR, AL_DBG_ERROR,\r
138                 //    ("Failed to allocate output buffer.\n") );\r
139                 return ND_NO_MEMORY;\r
140             }\r
141 \r
142             if( !DeviceIoControl( hIbatDev, IOCTL_IBAT_IP_ADDRESSES,\r
143                 &addrIn, sizeof(addrIn), pAddrOut, size, &size, NULL ) )\r
144             {\r
145                 HeapFree( GetProcessHeap(), 0, pAddrOut );\r
146                 //AL_PRINT( TRACE_LEVEL_ERROR, AL_DBG_ERROR,\r
147                 //    ("IOCTL_IBAT_IP_ADDRESSES failed (%x).\n", GetLastError()) );\r
148                 return ND_UNSUCCESSFUL;\r
149             }\r
150 \r
151             if( pAddrOut->Size > size )\r
152             {\r
153                 size = pAddrOut->Size;\r
154                 HeapFree( GetProcessHeap(), 0, pAddrOut );\r
155                 pAddrOut = NULL;\r
156             }\r
157 \r
158         } while( !pAddrOut );\r
159 \r
160         CloseHandle( hIbatDev );\r
161 \r
162         //\r
163         // Note: the required size computed is a few bytes larger than necessary, \r
164         // but that keeps the code clean.\r
165         //\r
166         SIZE_T size_req = sizeof(SOCKET_ADDRESS_LIST);\r
167 \r
168         switch( pAddrOut->AddressCount )\r
169         {\r
170         case 0:\r
171             break;\r
172 \r
173         default:\r
174             size_req += (pAddrOut->AddressCount - 1) *\r
175                 (sizeof(SOCKET_ADDRESS) + sizeof(SOCKADDR));\r
176             /* Fall through. */\r
177             __fallthrough;\r
178 \r
179         case 1:\r
180             /* Add the space for the first address. */\r
181             size_req += sizeof(SOCKADDR);\r
182             break;\r
183         }\r
184 \r
185         if( size_req > *pBufferSize )\r
186         {\r
187             HeapFree( GetProcessHeap(), 0, pAddrOut );\r
188             *pBufferSize = size_req;\r
189             return ND_BUFFER_OVERFLOW;\r
190         }\r
191 \r
192         ZeroMemory( pAddressList, *pBufferSize );\r
193 \r
194         /* We store the array of addresses after the last address pointer:\r
195         *      iAddressCount\r
196         *      Address[0]; <-- points to sockaddr[0]\r
197         *      Address[1]; <-- points to sockaddr[1]\r
198         *      ...\r
199         *      Address[n-1]; <-- points to sockaddr[n-1]\r
200         *      sockaddr[0];\r
201         *      sockaddr[1];\r
202         *      ...\r
203         *      sockaddr[n-1]\r
204         */\r
205         BYTE* pBuf = (BYTE*)(&(pAddressList->Address[pAddrOut->AddressCount]));\r
206         *pBufferSize = size_req;\r
207 \r
208         pAddressList->iAddressCount = 0;\r
209         for( LONG i = 0; i < pAddrOut->AddressCount; i++ )\r
210         {\r
211             pAddressList->Address[pAddressList->iAddressCount].lpSockaddr =\r
212                 (LPSOCKADDR)pBuf;\r
213 \r
214             switch( pAddrOut->Address[i].IpVersion )\r
215             {\r
216             case 4:\r
217                 {\r
218                     struct sockaddr_in* pAddr4 = ((struct sockaddr_in*)pBuf);\r
219                     pAddr4->sin_family = AF_INET;\r
220                     pAddr4->sin_addr.s_addr =\r
221                         *((u_long*)&pAddrOut->Address[i].Address[12]);\r
222                     pAddressList->Address[pAddressList->iAddressCount].iSockaddrLength =\r
223                         sizeof(struct sockaddr_in);\r
224                 }\r
225                 break;\r
226 \r
227             case 6:\r
228                 {\r
229                     struct sockaddr_in6* pAddr6 = ((struct sockaddr_in6*)pBuf);\r
230                     pAddr6->sin6_family = AF_INET6;\r
231                     CopyMemory(\r
232                         &pAddr6->sin6_addr,\r
233                         pAddrOut->Address[i].Address,\r
234                         sizeof(pAddr6->sin6_addr) );\r
235                     pAddressList->Address[pAddressList->iAddressCount].iSockaddrLength =\r
236                         sizeof(struct sockaddr_in6);\r
237                 }\r
238                 break;\r
239 \r
240             default:\r
241                 continue;\r
242             }\r
243 \r
244             pBuf += pAddressList->Address[pAddressList->iAddressCount++].iSockaddrLength;\r
245         }\r
246 \r
247         HeapFree( GetProcessHeap(), 0, pAddrOut );\r
248 \r
249         return S_OK;\r
250     }\r
251 \r
252     HRESULT CProvider::OpenAdapter(\r
253             __in_bcount(AddressLength) const struct sockaddr* pAddress,\r
254             __in SIZE_T AddressLength,\r
255             __deref_out INDAdapter** ppAdapter )\r
256     {\r
257         ND_ENTER( ND_DBG_NDI );\r
258 \r
259         if( AddressLength < sizeof(struct sockaddr) )\r
260             return ND_INVALID_ADDRESS;\r
261 \r
262         IOCTL_IBAT_IP_TO_PORT_IN in;\r
263         in.Version = IBAT_IOCTL_VERSION;\r
264 \r
265         switch( pAddress->sa_family )\r
266         {\r
267         case AF_INET:\r
268             if( AddressLength < sizeof(struct sockaddr_in) )\r
269                 return ND_INVALID_ADDRESS;\r
270             in.Address.IpVersion = 4;\r
271             RtlCopyMemory(\r
272                 &in.Address.Address[12],\r
273                 &((struct sockaddr_in*)pAddress)->sin_addr,\r
274                 sizeof( ((struct sockaddr_in*)pAddress)->sin_addr ) );\r
275             break;\r
276 \r
277         case AF_INET6:\r
278             if( AddressLength < sizeof(struct sockaddr_in6) )\r
279                 return ND_INVALID_ADDRESS;\r
280             in.Address.IpVersion = 6;\r
281             RtlCopyMemory(\r
282                 in.Address.Address,\r
283                 &((struct sockaddr_in6*)pAddress)->sin6_addr,\r
284                 sizeof(in.Address.Address) );\r
285             break;\r
286 \r
287         default:\r
288             return ND_INVALID_ADDRESS;\r
289         }\r
290 \r
291         HANDLE hIbatDev = CreateFileW( IBAT_WIN32_NAME,\r
292             MAXIMUM_ALLOWED, 0, NULL,\r
293             OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL );\r
294         if( hIbatDev == INVALID_HANDLE_VALUE )\r
295             return ND_NO_MEMORY;\r
296 \r
297         IBAT_PORT_RECORD out;\r
298         DWORD size;\r
299         BOOL fSuccess = DeviceIoControl( hIbatDev, IOCTL_IBAT_IP_TO_PORT,\r
300             &in, sizeof(in), &out, sizeof(out), &size, NULL );\r
301         \r
302         CloseHandle( hIbatDev );\r
303         if( !fSuccess || size == 0 )\r
304             return ND_INVALID_ADDRESS;\r
305 \r
306         return CAdapter::Create( this, pAddress, &out, ppAdapter );\r
307     }\r
308 \r
309     CClassFactory::CClassFactory(void) :\r
310         m_nRef( 1 )\r
311     {\r
312         InterlockedIncrement( &gnRef );\r
313     }\r
314 \r
315     CClassFactory::~CClassFactory(void)\r
316     {\r
317         InterlockedDecrement( &gnRef );\r
318     }\r
319 \r
320     HRESULT CClassFactory::QueryInterface(\r
321         REFIID riid,\r
322         void** ppObject )\r
323     {\r
324         if( IsEqualIID( riid, IID_IUnknown ) )\r
325         {\r
326             *ppObject = this;\r
327             return S_OK;\r
328         }\r
329         if( IsEqualIID( riid, IID_IClassFactory ) )\r
330         {\r
331             *ppObject = this;\r
332             return S_OK;\r
333         }\r
334 \r
335         return E_NOINTERFACE;\r
336     }\r
337 \r
338     ULONG CClassFactory::AddRef()\r
339     {\r
340         return InterlockedIncrement( &m_nRef );\r
341     }\r
342 \r
343     ULONG CClassFactory::Release()\r
344     {\r
345         ULONG ref = InterlockedDecrement( &m_nRef );\r
346         if( ref == 0 )\r
347             delete this;\r
348 \r
349         return ref;\r
350     }\r
351 \r
352     HRESULT CClassFactory::CreateInstance(\r
353         IUnknown* pUnkOuter,\r
354         REFIID riid,\r
355         void** ppObject )\r
356     {\r
357         if( pUnkOuter != NULL )\r
358             return CLASS_E_NOAGGREGATION;\r
359 \r
360         if( IsEqualIID( riid, IID_INDProvider ) )\r
361         {\r
362             *ppObject = new CProvider();\r
363             if( !*ppObject )\r
364                 return E_OUTOFMEMORY;\r
365 \r
366             return S_OK;\r
367         }\r
368 \r
369         return E_NOINTERFACE;\r
370     }\r
371 \r
372     HRESULT CClassFactory::LockServer( BOOL fLock )\r
373     { \r
374         UNREFERENCED_PARAMETER( fLock );\r
375         return S_OK;\r
376     }\r
377 \r
378 } // namespace\r
379 \r
380 void* __cdecl operator new(\r
381     size_t count\r
382     )\r
383 {\r
384     return HeapAlloc( GetProcessHeap(), 0, count );\r
385 }\r
386 \r
387 \r
388 void __cdecl operator delete(\r
389     void* object\r
390     )\r
391 {\r
392     HeapFree( GetProcessHeap(), 0, object );\r
393 }\r
394 \r
395 extern "C" {\r
396 STDAPI DllGetClassObject(\r
397     REFCLSID rclsid,\r
398     REFIID riid,\r
399     LPVOID * ppv\r
400     )\r
401 {\r
402     ND_ENTER( ND_DBG_NDI );\r
403 \r
404     UNREFERENCED_PARAMETER( rclsid );\r
405 \r
406     if( IsEqualIID( riid, IID_IClassFactory ) )\r
407     {\r
408         NetworkDirect::CClassFactory* pFactory = new NetworkDirect::CClassFactory();\r
409         if( pFactory == NULL )\r
410             return E_OUTOFMEMORY;\r
411 \r
412         *ppv = pFactory;\r
413         return S_OK;\r
414     }\r
415 \r
416     return E_NOINTERFACE;\r
417 }\r
418 \r
419 STDAPI DllCanUnloadNow(void)\r
420 {\r
421     ND_ENTER( ND_DBG_NDI );\r
422 \r
423     if( InterlockedCompareExchange( &NetworkDirect::gnRef, 0, 0 ) != 0 )\r
424         return S_FALSE;\r
425 \r
426     return S_OK;\r
427 }\r
428 \r
429 int\r
430 WSPAPI\r
431 WSPStartup(\r
432     IN WORD wVersionRequested,\r
433     OUT LPWSPDATA lpWSPData,\r
434     IN LPWSAPROTOCOL_INFOW lpProtocolInfo,\r
435     IN WSPUPCALLTABLE UpcallTable,\r
436     OUT LPWSPPROC_TABLE lpProcTable\r
437     )\r
438 {\r
439     UNREFERENCED_PARAMETER( wVersionRequested );\r
440     UNREFERENCED_PARAMETER( lpWSPData );\r
441     UNREFERENCED_PARAMETER( lpProtocolInfo );\r
442     UNREFERENCED_PARAMETER( UpcallTable );\r
443     UNREFERENCED_PARAMETER( lpProcTable );\r
444     return WSASYSNOTREADY;\r
445 }\r
446 \r
447 static BOOL\r
448 _DllMain(\r
449     IN                HINSTANCE                    hinstDll,\r
450     IN                DWORD                        dwReason,\r
451     IN                LPVOID                        lpvReserved )\r
452 {\r
453 \r
454     ND_ENTER( ND_DBG_NDI );\r
455 \r
456     UNUSED_PARAM( hinstDll );\r
457     UNUSED_PARAM( lpvReserved );\r
458 \r
459     switch( dwReason )\r
460     {\r
461     case DLL_PROCESS_ATTACH:\r
462 \r
463 \r
464 #if defined(EVENT_TRACING)\r
465 #if DBG\r
466         WPP_INIT_TRACING(L"ibndprov.dll");\r
467 #else\r
468         WPP_INIT_TRACING(L"ibndprov.dll");\r
469 #endif\r
470 #endif\r
471 \r
472 \r
473 #if !defined(EVENT_TRACING)\r
474 #if DBG \r
475         TCHAR    env_var[16];\r
476         DWORD    i;\r
477 \r
478         i = GetEnvironmentVariable( "IBNDPROV_DBG_LEVEL", env_var, sizeof(env_var) );\r
479         if( i && i <= 16 )\r
480         {\r
481             g_nd_dbg_level = _tcstoul( env_var, NULL, 16 );\r
482         }\r
483 \r
484         i = GetEnvironmentVariable( "IBNDPROV_DBG_FLAGS", env_var, sizeof(env_var) );\r
485         if( i && i <= 16 )\r
486         {\r
487             g_nd_dbg_flags = _tcstoul( env_var, NULL, 16 );\r
488         }\r
489 \r
490         if( g_nd_dbg_flags & ND_DBG_ERR )\r
491             g_nd_dbg_flags |= CL_DBG_ERROR;\r
492 \r
493         ND_PRINT( TRACE_LEVEL_ERROR, ND_DBG_ERR ,\r
494             ("(pcs %#x) IbNdProv: Debug print: level:%d, flags 0x%x\n",\r
495             GetCurrentProcessId(), g_nd_dbg_level ,g_nd_dbg_flags) );\r
496 \r
497 #endif\r
498 #endif\r
499 \r
500         ND_PRINT(TRACE_LEVEL_INFORMATION, ND_DBG_NDI, ("DllMain: DLL_PROCESS_ATTACH\n") );\r
501         break;\r
502 \r
503     case DLL_PROCESS_DETACH:\r
504         ND_PRINT(TRACE_LEVEL_INFORMATION, ND_DBG_NDI,\r
505             ("DllMain: DLL_PROCESS_DETACH, ref count %d\n", NetworkDirect::gnRef) );\r
506 \r
507 #if defined(EVENT_TRACING)\r
508         WPP_CLEANUP();\r
509 #endif\r
510         break;\r
511     }\r
512 \r
513     ND_EXIT( ND_DBG_NDI );\r
514 \r
515     return TRUE;\r
516 }\r
517 \r
518 \r
519 extern BOOL APIENTRY\r
520 _DllMainCRTStartupForGS(\r
521     IN                HINSTANCE                    h_module,\r
522     IN                DWORD                        ul_reason_for_call, \r
523     IN                LPVOID                        lp_reserved );\r
524 \r
525 \r
526 BOOL APIENTRY\r
527 DllMain(\r
528     IN                HINSTANCE                    h_module,\r
529     IN                DWORD                        ul_reason_for_call, \r
530     IN                LPVOID                        lp_reserved )\r
531 {\r
532     switch( ul_reason_for_call )\r
533     {\r
534     case DLL_PROCESS_ATTACH:\r
535         if( !_DllMainCRTStartupForGS(\r
536             h_module, ul_reason_for_call, lp_reserved ) )\r
537         {\r
538             return FALSE;\r
539         }\r
540 \r
541         return _DllMain( h_module, ul_reason_for_call, lp_reserved );\r
542 \r
543     case DLL_THREAD_ATTACH:\r
544         ND_PRINT(TRACE_LEVEL_INFORMATION, ND_DBG_NDI, ("DllMain: DLL_THREAD_ATTACH\n") );\r
545         break;\r
546 \r
547     case DLL_THREAD_DETACH:\r
548         ND_PRINT(TRACE_LEVEL_INFORMATION, ND_DBG_NDI, ("DllMain: DLL_THREAD_DETACH\n") );\r
549         break;\r
550 \r
551     case DLL_PROCESS_DETACH:\r
552         _DllMain( h_module, ul_reason_for_call, lp_reserved );\r
553 \r
554         return _DllMainCRTStartupForGS(\r
555             h_module, ul_reason_for_call, lp_reserved );\r
556     }\r
557     return TRUE;\r
558 }\r
559 \r
560 }   // extern "C"\r
561 \r