Revert "[httpdisk] Apply WinVBlock-as-usual indentation"
[people/sha0/winvblock.git] / src / httpdisk / ksocket.c
1 /*
2     HTTP Virtual Disk.
3     Copyright (C) 2006 Bo Brantén.
4     This program is free software; you can redistribute it and/or modify
5     it under the terms of the GNU General Public License as published by
6     the Free Software Foundation; either version 2 of the License, or
7     (at your option) any later version.
8     This program is distributed in the hope that it will be useful,
9     but WITHOUT ANY WARRANTY; without even the implied warranty of
10     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11     GNU General Public License for more details.
12     You should have received a copy of the GNU General Public License
13     along with this program; if not, write to the Free Software
14     Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
15 */
16
17 #include <ntddk.h>
18 #include <tdikrnl.h>
19 #include "ktdi.h"
20 #include "ksocket.h"
21
22 typedef struct _STREAM_SOCKET {
23     HANDLE              connectionHandle;
24     PFILE_OBJECT        connectionFileObject;
25     KEVENT              disconnectEvent;
26 } STREAM_SOCKET, *PSTREAM_SOCKET;
27
28 typedef struct _SOCKET {
29     int                 type;
30     BOOLEAN             isBound;
31     BOOLEAN             isConnected;
32     BOOLEAN             isListening;
33     BOOLEAN             isShuttingdown;
34     BOOLEAN             isShared;
35     HANDLE              addressHandle;
36     PFILE_OBJECT        addressFileObject;
37     PSTREAM_SOCKET      streamSocket;
38     struct sockaddr     peer;
39 } SOCKET, *PSOCKET;
40
41 NTSTATUS event_disconnect(PVOID TdiEventContext, CONNECTION_CONTEXT ConnectionContext, LONG DisconnectDataLength,
42                           PVOID DisconnectData, LONG DisconnectInformationLength, PVOID DisconnectInformation,
43                           ULONG DisconnectFlags)
44 {
45     PSOCKET s = (PSOCKET) TdiEventContext;
46     PSTREAM_SOCKET streamSocket = (PSTREAM_SOCKET) ConnectionContext;
47     KeSetEvent(&streamSocket->disconnectEvent, 0, FALSE);
48     return STATUS_SUCCESS;
49 }
50
51 int __cdecl accept(int socket, struct sockaddr *addr, int *addrlen)
52 {
53     return -1;
54 }
55
56 int __cdecl bind(int socket, const struct sockaddr *addr, int addrlen)
57 {
58     PSOCKET s = (PSOCKET) -socket;
59     const struct sockaddr_in* localAddr = (const struct sockaddr_in*) addr;
60     UNICODE_STRING devName;
61     NTSTATUS status;
62
63     if (s->isBound || addr == NULL || addrlen < sizeof(struct sockaddr_in))
64     {
65         return -1;
66     }
67
68     if (s->type == SOCK_DGRAM)
69     {
70         RtlInitUnicodeString(&devName, L"\\Device\\Udp");
71     }
72     else if (s->type == SOCK_STREAM)
73     {
74         RtlInitUnicodeString(&devName, L"\\Device\\Tcp");
75     }
76     else
77     {
78         return -1;
79     }
80
81     status = tdi_open_transport_address(
82         &devName,
83         localAddr->sin_addr.s_addr,
84         localAddr->sin_port,
85         s->isShared,
86         &s->addressHandle,
87         &s->addressFileObject
88         );
89
90     if (!NT_SUCCESS(status))
91     {
92         s->addressFileObject = NULL;
93         s->addressHandle = (HANDLE) -1;
94         return status;
95     }
96
97     if (s->type == SOCK_STREAM)
98     {
99         tdi_set_event_handler(s->addressFileObject, TDI_EVENT_DISCONNECT, event_disconnect, s);
100     }
101
102     s->isBound = TRUE;
103
104     return 0;
105 }
106
107 int __cdecl close(int socket)
108 {
109     PSOCKET s = (PSOCKET) -socket;
110
111     if (s->isBound)
112     {
113         if (s->type == SOCK_STREAM && s->streamSocket)
114         {
115             if (s->isConnected)
116             {
117                 if (!s->isShuttingdown)
118                 {
119                     tdi_disconnect(s->streamSocket->connectionFileObject, TDI_DISCONNECT_RELEASE);
120                 }
121                 //KeWaitForSingleObject(&s->streamSocket->disconnectEvent, Executive, KernelMode, FALSE, NULL);
122             }
123             if (s->streamSocket->connectionFileObject)
124             {
125                 tdi_disassociate_address(s->streamSocket->connectionFileObject);
126                 ObDereferenceObject(s->streamSocket->connectionFileObject);
127             }
128             if (s->streamSocket->connectionHandle != (HANDLE) -1)
129             {
130                 ZwClose(s->streamSocket->connectionHandle);
131             }
132             ExFreePool(s->streamSocket);
133         }
134
135         if (s->type == SOCK_DGRAM || s->type == SOCK_STREAM)
136         {
137             ObDereferenceObject(s->addressFileObject);
138             if (s->addressHandle != (HANDLE) -1)
139             {
140                 ZwClose(s->addressHandle);
141             }
142         }
143     }
144
145     ExFreePool(s);
146
147     return 0;
148 }
149
150 int __cdecl connect(int socket, const struct sockaddr *addr, int addrlen)
151 {
152     PSOCKET s = (PSOCKET) -socket;
153     const struct sockaddr_in* remoteAddr = (const struct sockaddr_in*) addr;
154     UNICODE_STRING devName;
155     NTSTATUS status;
156
157     if (addr == NULL || addrlen < sizeof(struct sockaddr_in))
158     {
159         return -1;
160     }
161
162     if (!s->isBound)
163     {
164         struct sockaddr_in localAddr;
165
166         localAddr.sin_family = AF_INET;
167         localAddr.sin_port = 0;
168         localAddr.sin_addr.s_addr = INADDR_ANY;
169
170         status = bind(socket, (struct sockaddr*) &localAddr, sizeof(localAddr));
171
172         if (!NT_SUCCESS(status))
173         {
174             return status;
175         }
176     }
177
178     if (s->type == SOCK_STREAM)
179     {
180         if (s->isConnected || s->isListening)
181         {
182             return -1;
183         }
184
185         if (!s->streamSocket)
186         {
187             s->streamSocket = (PSTREAM_SOCKET) ExAllocatePool(NonPagedPool, sizeof(STREAM_SOCKET));
188
189             if (!s->streamSocket)
190             {
191                 return STATUS_INSUFFICIENT_RESOURCES;
192             }
193
194             RtlZeroMemory(s->streamSocket, sizeof(STREAM_SOCKET));
195             s->streamSocket->connectionHandle = (HANDLE) -1;
196             KeInitializeEvent(&s->streamSocket->disconnectEvent, NotificationEvent, FALSE);
197         }
198
199         RtlInitUnicodeString(&devName, L"\\Device\\Tcp");
200
201         status = tdi_open_connection_endpoint(
202             &devName,
203             s->streamSocket,
204             s->isShared,
205             &s->streamSocket->connectionHandle,
206             &s->streamSocket->connectionFileObject
207             );
208
209         if (!NT_SUCCESS(status))
210         {
211             s->streamSocket->connectionFileObject = NULL;
212             s->streamSocket->connectionHandle = (HANDLE) -1;
213             return status;
214         }
215
216         status = tdi_associate_address(s->streamSocket->connectionFileObject, s->addressHandle);
217
218         if (!NT_SUCCESS(status))
219         {
220             ObDereferenceObject(s->streamSocket->connectionFileObject);
221             s->streamSocket->connectionFileObject = NULL;
222             ZwClose(s->streamSocket->connectionHandle);
223             s->streamSocket->connectionHandle = (HANDLE) -1;
224             return status;
225         }
226
227         status = tdi_connect(
228             s->streamSocket->connectionFileObject,
229             remoteAddr->sin_addr.s_addr,
230             remoteAddr->sin_port
231             );
232
233         if (!NT_SUCCESS(status))
234         {
235             tdi_disassociate_address(s->streamSocket->connectionFileObject);
236             ObDereferenceObject(s->streamSocket->connectionFileObject);
237             s->streamSocket->connectionFileObject = NULL;
238             ZwClose(s->streamSocket->connectionHandle);
239             s->streamSocket->connectionHandle = (HANDLE) -1;
240             return status;
241         }
242         else
243         {
244             s->peer = *addr;
245             s->isConnected = TRUE;
246             return 0;
247         }
248     }
249     else if (s->type == SOCK_DGRAM)
250     {
251         s->peer = *addr;
252         if (remoteAddr->sin_addr.s_addr == 0 && remoteAddr->sin_port == 0)
253         {
254             s->isConnected = FALSE;
255         }
256         else
257         {
258             s->isConnected = TRUE;
259         }
260         return 0;
261     }
262     else
263     {
264         return -1;
265     }
266 }
267
268 int __cdecl getpeername(int socket, struct sockaddr *addr, int *addrlen)
269 {
270     PSOCKET s = (PSOCKET) -socket;
271
272     if (!s->isConnected || addr == NULL || addrlen == NULL || *addrlen < sizeof(struct sockaddr_in))
273     {
274         return -1;
275     }
276
277     *addr = s->peer;
278     *addrlen = sizeof(s->peer);
279
280     return 0;
281 }
282
283 int __cdecl getsockname(int socket, struct sockaddr *addr, int *addrlen)
284 {
285     PSOCKET s = (PSOCKET) -socket;
286     struct sockaddr_in* localAddr = (struct sockaddr_in*) addr;
287
288     if (!s->isBound || addr == NULL || addrlen == NULL || *addrlen < sizeof(struct sockaddr_in))
289     {
290         return -1;
291     }
292
293     if (s->type == SOCK_DGRAM)
294     {
295         *addrlen = sizeof(struct sockaddr_in);
296
297         return tdi_query_address(
298             s->addressFileObject,
299             &localAddr->sin_addr.s_addr,
300             &localAddr->sin_port
301             );
302     }
303     else if (s->type == SOCK_STREAM)
304     {
305         *addrlen = sizeof(struct sockaddr_in);
306
307         return tdi_query_address(
308             s->streamSocket && s->streamSocket->connectionFileObject ? s->streamSocket->connectionFileObject : s->addressFileObject,
309             &localAddr->sin_addr.s_addr,
310             &localAddr->sin_port
311             );
312     }
313     else
314     {
315         return -1;
316     }
317 }
318
319 int __cdecl getsockopt(int socket, int level, int optname, char *optval, int *optlen)
320 {
321     return -1;
322 }
323
324 int __cdecl listen(int socket, int backlog)
325 {
326     return -1;
327 }
328
329 int __cdecl recv(int socket, char *buf, int len, int flags)
330 {
331     PSOCKET s = (PSOCKET) -socket;
332
333     if (s->type == SOCK_DGRAM)
334     {
335         return recvfrom(socket, buf, len, flags, 0, 0);
336     }
337     else if (s->type == SOCK_STREAM)
338     {
339         if (!s->isConnected)
340         {
341             return -1;
342         }
343
344         return tdi_recv_stream(
345             s->streamSocket->connectionFileObject,
346             buf,
347             len,
348             flags == MSG_OOB ? TDI_RECEIVE_EXPEDITED : TDI_RECEIVE_NORMAL
349             );
350     }
351     else
352     {
353         return -1;
354     }
355 }
356
357 int __cdecl recvfrom(int socket, char *buf, int len, int flags, struct sockaddr *addr, int *addrlen)
358 {
359     PSOCKET s = (PSOCKET) -socket;
360     struct sockaddr_in* returnAddr = (struct sockaddr_in*) addr;
361
362     if (s->type == SOCK_STREAM)
363     {
364         return recv(socket, buf, len, flags);
365     }
366     else if (s->type == SOCK_DGRAM)
367     {
368         u_long* sin_addr = 0;
369         u_short* sin_port = 0;
370
371         if (!s->isBound)
372         {
373             return -1;
374         }
375
376         if (addr != NULL && addrlen != NULL && *addrlen >= sizeof(struct sockaddr_in))
377         {
378             sin_addr = &returnAddr->sin_addr.s_addr;
379             sin_port = &returnAddr->sin_port;
380             *addrlen = sizeof(struct sockaddr_in);
381         }
382
383         return tdi_recv_dgram(
384             s->addressFileObject,
385             sin_addr,
386             sin_port,
387             buf,
388             len,
389             TDI_RECEIVE_NORMAL
390             );
391     }
392     else
393     {
394         return -1;
395     }
396 }
397
398 int __cdecl select(int nfds, fd_set *readfds, fd_set *writefds, fd_set *exceptfds, const struct timeval *timeout)
399 {
400     return -1;
401 }
402
403 int __cdecl send(int socket, const char *buf, int len, int flags)
404 {
405     PSOCKET s = (PSOCKET) -socket;
406
407     if (!s->isConnected)
408     {
409         return -1;
410     }
411
412     if (s->type == SOCK_DGRAM)
413     {
414         return sendto(socket, buf, len, flags, &s->peer, sizeof(s->peer));
415     }
416     else if (s->type == SOCK_STREAM)
417     {
418         return tdi_send_stream(
419             s->streamSocket->connectionFileObject,
420             buf,
421             len,
422             flags == MSG_OOB ? TDI_SEND_EXPEDITED : 0
423             );
424     }
425     else
426     {
427         return -1;
428     }
429 }
430
431 int __cdecl sendto(int socket, const char *buf, int len, int flags, const struct sockaddr *addr, int addrlen)
432 {
433     PSOCKET s = (PSOCKET) -socket;
434     const struct sockaddr_in* remoteAddr = (const struct sockaddr_in*) addr;
435
436     if (s->type == SOCK_STREAM)
437     {
438         return send(socket, buf, len, flags);
439     }
440     else if (s->type == SOCK_DGRAM)
441     {
442         if (addr == NULL || addrlen < sizeof(struct sockaddr_in))
443         {
444             return -1;
445         }
446
447         if (!s->isBound)
448         {
449             struct sockaddr_in localAddr;
450             NTSTATUS status;
451
452             localAddr.sin_family = AF_INET;
453             localAddr.sin_port = 0;
454             localAddr.sin_addr.s_addr = INADDR_ANY;
455
456             status = bind(socket, (struct sockaddr*) &localAddr, sizeof(localAddr));
457
458             if (!NT_SUCCESS(status))
459             {
460                 return status;
461             }
462         }
463
464         return tdi_send_dgram(
465             s->addressFileObject,
466             remoteAddr->sin_addr.s_addr,
467             remoteAddr->sin_port,
468             buf,
469             len
470             );
471     }
472     else
473     {
474         return -1;
475     }
476 }
477
478 int __cdecl setsockopt(int socket, int level, int optname, const char *optval, int optlen)
479 {
480     return -1;
481 }
482
483 int __cdecl shutdown(int socket, int how)
484 {
485     PSOCKET s = (PSOCKET) -socket;
486
487     if (!s->isConnected)
488     {
489         return -1;
490     }
491
492     if (s->type == SOCK_STREAM)
493     {
494         s->isShuttingdown = TRUE;
495         return tdi_disconnect(s->streamSocket->connectionFileObject, TDI_DISCONNECT_RELEASE);
496     }
497     else
498     {
499         return -1;
500     }
501 }
502
503 int __cdecl socket(int af, int type, int protocol)
504 {
505     PSOCKET s;
506
507     if (af != AF_INET ||
508        (type != SOCK_DGRAM && type != SOCK_STREAM) ||
509        (type == SOCK_DGRAM && protocol != IPPROTO_UDP && protocol != 0) ||
510        (type == SOCK_STREAM && protocol != IPPROTO_TCP && protocol != 0)
511        )
512     {
513         return STATUS_INVALID_PARAMETER;
514     }
515
516     s = (PSOCKET) ExAllocatePool(NonPagedPool, sizeof(SOCKET));
517
518     if (!s)
519     {
520         return STATUS_INSUFFICIENT_RESOURCES;
521     }
522
523     RtlZeroMemory(s, sizeof(SOCKET));
524
525     s->type = type;
526     s->addressHandle = (HANDLE) -1;
527
528     return -(int)s;
529 }