d5ec67fbe5545c5e9eabfd02af97142d1240aac9
[people/sha0/winvblock.git] / src / httpdisk / ksocket.c
1 /*
2     HTTP Virtual Disk.
3     Copyright (C) 2006 Bo Brantén.
4     This program is free software; you can redistribute it and/or modify
5     it under the terms of the GNU General Public License as published by
6     the Free Software Foundation; either version 2 of the License, or
7     (at your option) any later version.
8     This program is distributed in the hope that it will be useful,
9     but WITHOUT ANY WARRANTY; without even the implied warranty of
10     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11     GNU General Public License for more details.
12     You should have received a copy of the GNU General Public License
13     along with this program; if not, write to the Free Software
14     Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15 */
16
17 #include <ntddk.h>
18 #include <tdikrnl.h>
19 #include "ktdi.h"
20 #include "ksocket.h"
21
22 typedef struct _STREAM_SOCKET
23 {
24   HANDLE connectionHandle;
25   PFILE_OBJECT connectionFileObject;
26   KEVENT disconnectEvent;
27 } STREAM_SOCKET,
28 *PSTREAM_SOCKET;
29
30 typedef struct _SOCKET
31 {
32   int type;
33   BOOLEAN isBound;
34   BOOLEAN isConnected;
35   BOOLEAN isListening;
36   BOOLEAN isShuttingdown;
37   BOOLEAN isShared;
38   HANDLE addressHandle;
39   PFILE_OBJECT addressFileObject;
40   PSTREAM_SOCKET streamSocket;
41   struct sockaddr peer;
42 } SOCKET,
43 *PSOCKET;
44
45 NTSTATUS
46 event_disconnect (
47   PVOID TdiEventContext,
48   CONNECTION_CONTEXT ConnectionContext,
49   LONG DisconnectDataLength,
50   PVOID DisconnectData,
51   LONG DisconnectInformationLength,
52   PVOID DisconnectInformation,
53   ULONG DisconnectFlags
54  )
55 {
56   PSOCKET s = ( PSOCKET ) TdiEventContext;
57   PSTREAM_SOCKET streamSocket = ( PSTREAM_SOCKET ) ConnectionContext;
58   KeSetEvent ( &streamSocket->disconnectEvent, 0, FALSE );
59   return STATUS_SUCCESS;
60 }
61
62 int __cdecl
63 accept (
64   int socket,
65   struct sockaddr *addr,
66   int *addrlen
67  )
68 {
69   return -1;
70 }
71
72 int __cdecl
73 bind (
74   int socket,
75   const struct sockaddr *addr,
76   int addrlen
77  )
78 {
79   PSOCKET s = ( PSOCKET ) - socket;
80   const struct sockaddr_in *localAddr = ( const struct sockaddr_in * )addr;
81   UNICODE_STRING devName;
82   NTSTATUS status;
83
84   if ( s->isBound || addr == NULL || addrlen < sizeof ( struct sockaddr_in ) )
85     {
86       return -1;
87     }
88
89   if ( s->type == SOCK_DGRAM )
90     {
91       RtlInitUnicodeString ( &devName, L"\\Device\\Udp" );
92     }
93   else if ( s->type == SOCK_STREAM )
94     {
95       RtlInitUnicodeString ( &devName, L"\\Device\\Tcp" );
96     }
97   else
98     {
99       return -1;
100     }
101
102   status =
103     tdi_open_transport_address ( &devName, localAddr->sin_addr.s_addr,
104                                  localAddr->sin_port, s->isShared,
105                                  &s->addressHandle, &s->addressFileObject );
106
107   if ( !NT_SUCCESS ( status ) )
108     {
109       s->addressFileObject = NULL;
110       s->addressHandle = ( HANDLE ) - 1;
111       return status;
112     }
113
114   if ( s->type == SOCK_STREAM )
115     {
116       tdi_set_event_handler ( s->addressFileObject, TDI_EVENT_DISCONNECT,
117                               event_disconnect, s );
118     }
119
120   s->isBound = TRUE;
121
122   return 0;
123 }
124
125 int __cdecl
126 close (
127   int socket
128  )
129 {
130   PSOCKET s = ( PSOCKET ) - socket;
131
132   if ( s->isBound )
133     {
134       if ( s->type == SOCK_STREAM && s->streamSocket )
135         {
136           if ( s->isConnected )
137             {
138               if ( !s->isShuttingdown )
139                 {
140                   tdi_disconnect ( s->streamSocket->connectionFileObject,
141                                    TDI_DISCONNECT_RELEASE );
142                 }
143               //KeWaitForSingleObject(&s->streamSocket->disconnectEvent, Executive, KernelMode, FALSE, NULL);
144             }
145           if ( s->streamSocket->connectionFileObject )
146             {
147               tdi_disassociate_address ( s->
148                                          streamSocket->connectionFileObject );
149               ObDereferenceObject ( s->streamSocket->connectionFileObject );
150             }
151           if ( s->streamSocket->connectionHandle != ( HANDLE ) - 1 )
152             {
153               ZwClose ( s->streamSocket->connectionHandle );
154             }
155           ExFreePool ( s->streamSocket );
156         }
157
158       if ( s->type == SOCK_DGRAM || s->type == SOCK_STREAM )
159         {
160           ObDereferenceObject ( s->addressFileObject );
161           if ( s->addressHandle != ( HANDLE ) - 1 )
162             {
163               ZwClose ( s->addressHandle );
164             }
165         }
166     }
167
168   ExFreePool ( s );
169
170   return 0;
171 }
172
173 int __cdecl
174 connect (
175   int socket,
176   const struct sockaddr *addr,
177   int addrlen
178  )
179 {
180   PSOCKET s = ( PSOCKET ) - socket;
181   const struct sockaddr_in *remoteAddr = ( const struct sockaddr_in * )addr;
182   UNICODE_STRING devName;
183   NTSTATUS status;
184
185   if ( addr == NULL || addrlen < sizeof ( struct sockaddr_in ) )
186     {
187       return -1;
188     }
189
190   if ( !s->isBound )
191     {
192       struct sockaddr_in localAddr;
193
194       localAddr.sin_family = AF_INET;
195       localAddr.sin_port = 0;
196       localAddr.sin_addr.s_addr = INADDR_ANY;
197
198       status =
199         bind ( socket, ( struct sockaddr * )&localAddr, sizeof ( localAddr ) );
200
201       if ( !NT_SUCCESS ( status ) )
202         {
203           return status;
204         }
205     }
206
207   if ( s->type == SOCK_STREAM )
208     {
209       if ( s->isConnected || s->isListening )
210         {
211           return -1;
212         }
213
214       if ( !s->streamSocket )
215         {
216           s->streamSocket =
217             ( PSTREAM_SOCKET ) ExAllocatePool ( NonPagedPool,
218                                                 sizeof ( STREAM_SOCKET ) );
219
220           if ( !s->streamSocket )
221             {
222               return STATUS_INSUFFICIENT_RESOURCES;
223             }
224
225           RtlZeroMemory ( s->streamSocket, sizeof ( STREAM_SOCKET ) );
226           s->streamSocket->connectionHandle = ( HANDLE ) - 1;
227           KeInitializeEvent ( &s->streamSocket->disconnectEvent,
228                               NotificationEvent, FALSE );
229         }
230
231       RtlInitUnicodeString ( &devName, L"\\Device\\Tcp" );
232
233       status =
234         tdi_open_connection_endpoint ( &devName, s->streamSocket, s->isShared,
235                                        &s->streamSocket->connectionHandle,
236                                        &s->
237                                        streamSocket->connectionFileObject );
238
239       if ( !NT_SUCCESS ( status ) )
240         {
241           s->streamSocket->connectionFileObject = NULL;
242           s->streamSocket->connectionHandle = ( HANDLE ) - 1;
243           return status;
244         }
245
246       status =
247         tdi_associate_address ( s->streamSocket->connectionFileObject,
248                                 s->addressHandle );
249
250       if ( !NT_SUCCESS ( status ) )
251         {
252           ObDereferenceObject ( s->streamSocket->connectionFileObject );
253           s->streamSocket->connectionFileObject = NULL;
254           ZwClose ( s->streamSocket->connectionHandle );
255           s->streamSocket->connectionHandle = ( HANDLE ) - 1;
256           return status;
257         }
258
259       status =
260         tdi_connect ( s->streamSocket->connectionFileObject,
261                       remoteAddr->sin_addr.s_addr, remoteAddr->sin_port );
262
263       if ( !NT_SUCCESS ( status ) )
264         {
265           tdi_disassociate_address ( s->streamSocket->connectionFileObject );
266           ObDereferenceObject ( s->streamSocket->connectionFileObject );
267           s->streamSocket->connectionFileObject = NULL;
268           ZwClose ( s->streamSocket->connectionHandle );
269           s->streamSocket->connectionHandle = ( HANDLE ) - 1;
270           return status;
271         }
272       else
273         {
274           s->peer = *addr;
275           s->isConnected = TRUE;
276           return 0;
277         }
278     }
279   else if ( s->type == SOCK_DGRAM )
280     {
281       s->peer = *addr;
282       if ( remoteAddr->sin_addr.s_addr == 0 && remoteAddr->sin_port == 0 )
283         {
284           s->isConnected = FALSE;
285         }
286       else
287         {
288           s->isConnected = TRUE;
289         }
290       return 0;
291     }
292   else
293     {
294       return -1;
295     }
296 }
297
298 int __cdecl
299 getpeername (
300   int socket,
301   struct sockaddr *addr,
302   int *addrlen
303  )
304 {
305   PSOCKET s = ( PSOCKET ) - socket;
306
307   if ( !s->isConnected || addr == NULL || addrlen == NULL
308        || *addrlen < sizeof ( struct sockaddr_in ) )
309     {
310       return -1;
311     }
312
313   *addr = s->peer;
314   *addrlen = sizeof ( s->peer );
315
316   return 0;
317 }
318
319 int __cdecl
320 getsockname (
321   int socket,
322   struct sockaddr *addr,
323   int *addrlen
324  )
325 {
326   PSOCKET s = ( PSOCKET ) - socket;
327   struct sockaddr_in *localAddr = ( struct sockaddr_in * )addr;
328
329   if ( !s->isBound || addr == NULL || addrlen == NULL
330        || *addrlen < sizeof ( struct sockaddr_in ) )
331     {
332       return -1;
333     }
334
335   if ( s->type == SOCK_DGRAM )
336     {
337       *addrlen = sizeof ( struct sockaddr_in );
338
339       return tdi_query_address ( s->addressFileObject,
340                                  &localAddr->sin_addr.s_addr,
341                                  &localAddr->sin_port );
342     }
343   else if ( s->type == SOCK_STREAM )
344     {
345       *addrlen = sizeof ( struct sockaddr_in );
346
347       return tdi_query_address ( s->streamSocket
348                                  && s->streamSocket->
349                                  connectionFileObject ? s->streamSocket->
350                                  connectionFileObject : s->addressFileObject,
351                                  &localAddr->sin_addr.s_addr,
352                                  &localAddr->sin_port );
353     }
354   else
355     {
356       return -1;
357     }
358 }
359
360 int __cdecl
361 getsockopt (
362   int socket,
363   int level,
364   int optname,
365   char *optval,
366   int *optlen
367  )
368 {
369   return -1;
370 }
371
372 int __cdecl
373 listen (
374   int socket,
375   int backlog
376  )
377 {
378   return -1;
379 }
380
381 int __cdecl
382 recv (
383   int socket,
384   char *buf,
385   int len,
386   int flags
387  )
388 {
389   PSOCKET s = ( PSOCKET ) - socket;
390
391   if ( s->type == SOCK_DGRAM )
392     {
393       return recvfrom ( socket, buf, len, flags, 0, 0 );
394     }
395   else if ( s->type == SOCK_STREAM )
396     {
397       if ( !s->isConnected )
398         {
399           return -1;
400         }
401
402       return tdi_recv_stream ( s->streamSocket->connectionFileObject, buf, len,
403                                flags ==
404                                MSG_OOB ? TDI_RECEIVE_EXPEDITED :
405                                TDI_RECEIVE_NORMAL );
406     }
407   else
408     {
409       return -1;
410     }
411 }
412
413 int __cdecl
414 recvfrom (
415   int socket,
416   char *buf,
417   int len,
418   int flags,
419   struct sockaddr *addr,
420   int *addrlen
421  )
422 {
423   PSOCKET s = ( PSOCKET ) - socket;
424   struct sockaddr_in *returnAddr = ( struct sockaddr_in * )addr;
425
426   if ( s->type == SOCK_STREAM )
427     {
428       return recv ( socket, buf, len, flags );
429     }
430   else if ( s->type == SOCK_DGRAM )
431     {
432       u_long *sin_addr = 0;
433       u_short *sin_port = 0;
434
435       if ( !s->isBound )
436         {
437           return -1;
438         }
439
440       if ( addr != NULL && addrlen != NULL
441            && *addrlen >= sizeof ( struct sockaddr_in ) )
442         {
443           sin_addr = &returnAddr->sin_addr.s_addr;
444           sin_port = &returnAddr->sin_port;
445           *addrlen = sizeof ( struct sockaddr_in );
446         }
447
448       return tdi_recv_dgram ( s->addressFileObject, sin_addr, sin_port, buf,
449                               len, TDI_RECEIVE_NORMAL );
450     }
451   else
452     {
453       return -1;
454     }
455 }
456
457 int __cdecl
458 select (
459   int nfds,
460   fd_set * readfds,
461   fd_set * writefds,
462   fd_set * exceptfds,
463   const struct timeval *timeout
464  )
465 {
466   return -1;
467 }
468
469 int __cdecl
470 send (
471   int socket,
472   const char *buf,
473   int len,
474   int flags
475  )
476 {
477   PSOCKET s = ( PSOCKET ) - socket;
478
479   if ( !s->isConnected )
480     {
481       return -1;
482     }
483
484   if ( s->type == SOCK_DGRAM )
485     {
486       return sendto ( socket, buf, len, flags, &s->peer, sizeof ( s->peer ) );
487     }
488   else if ( s->type == SOCK_STREAM )
489     {
490       return tdi_send_stream ( s->streamSocket->connectionFileObject, buf, len,
491                                flags == MSG_OOB ? TDI_SEND_EXPEDITED : 0 );
492     }
493   else
494     {
495       return -1;
496     }
497 }
498
499 int __cdecl
500 sendto (
501   int socket,
502   const char *buf,
503   int len,
504   int flags,
505   const struct sockaddr *addr,
506   int addrlen
507  )
508 {
509   PSOCKET s = ( PSOCKET ) - socket;
510   const struct sockaddr_in *remoteAddr = ( const struct sockaddr_in * )addr;
511
512   if ( s->type == SOCK_STREAM )
513     {
514       return send ( socket, buf, len, flags );
515     }
516   else if ( s->type == SOCK_DGRAM )
517     {
518       if ( addr == NULL || addrlen < sizeof ( struct sockaddr_in ) )
519         {
520           return -1;
521         }
522
523       if ( !s->isBound )
524         {
525           struct sockaddr_in localAddr;
526           NTSTATUS status;
527
528           localAddr.sin_family = AF_INET;
529           localAddr.sin_port = 0;
530           localAddr.sin_addr.s_addr = INADDR_ANY;
531
532           status =
533             bind ( socket, ( struct sockaddr * )&localAddr,
534                    sizeof ( localAddr ) );
535
536           if ( !NT_SUCCESS ( status ) )
537             {
538               return status;
539             }
540         }
541
542       return tdi_send_dgram ( s->addressFileObject,
543                               remoteAddr->sin_addr.s_addr,
544                               remoteAddr->sin_port, buf, len );
545     }
546   else
547     {
548       return -1;
549     }
550 }
551
552 int __cdecl
553 setsockopt (
554   int socket,
555   int level,
556   int optname,
557   const char *optval,
558   int optlen
559  )
560 {
561   return -1;
562 }
563
564 int __cdecl
565 shutdown (
566   int socket,
567   int how
568  )
569 {
570   PSOCKET s = ( PSOCKET ) - socket;
571
572   if ( !s->isConnected )
573     {
574       return -1;
575     }
576
577   if ( s->type == SOCK_STREAM )
578     {
579       s->isShuttingdown = TRUE;
580       return tdi_disconnect ( s->streamSocket->connectionFileObject,
581                               TDI_DISCONNECT_RELEASE );
582     }
583   else
584     {
585       return -1;
586     }
587 }
588
589 int __cdecl
590 socket (
591   int af,
592   int type,
593   int protocol
594  )
595 {
596   PSOCKET s;
597
598   if ( af != AF_INET || ( type != SOCK_DGRAM && type != SOCK_STREAM )
599        || ( type == SOCK_DGRAM && protocol != IPPROTO_UDP && protocol != 0 )
600        || ( type == SOCK_STREAM && protocol != IPPROTO_TCP && protocol != 0 ) )
601     {
602       return STATUS_INVALID_PARAMETER;
603     }
604
605   s = ( PSOCKET ) ExAllocatePool ( NonPagedPool, sizeof ( SOCKET ) );
606
607   if ( !s )
608     {
609       return STATUS_INSUFFICIENT_RESOURCES;
610     }
611
612   RtlZeroMemory ( s, sizeof ( SOCKET ) );
613
614   s->type = type;
615   s->addressHandle = ( HANDLE ) - 1;
616
617   return -( int )s;
618 }