001 // vim: set ft=c:
002 
// UDP header as it appears at the start of an IPv4 payload.
// All fields are kept in network byte order: this file writes them with
// htons() and reads them back with ntohs().
class CUdpHeader {
  U16 source_port;  // sender's port
  U16 dest_port;    // receiver's port
  U16 length;       // header + payload size in bytes
  U16 checksum;     // written as 0 on send; not inspected on receive
};
009 
// Per-socket state for a UDP (AF_INET / SOCK_DGRAM) socket.
class CUdpSocket {
  // Generic socket part; its function slots are filled in by UdpSocket().
  // Must stay the first member so the object can be handed around as a CSocket.
  CSocket sock;

  // Receive timeout in milliseconds (0 = wait forever), set via setsockopt.
  I64 rcvtimeo_ms;
  // Jiffy count after which the current recvfrom gives up.
  I64 recv_maxtime;

  // Destination of the pending recvfrom; NULL when no receive is pending.
  // UdpHandler resets it to NULL to signal that a datagram was delivered.
  U8 *recv_buf;
  // In: capacity of recv_buf. Out: bytes received, or -1 on timeout.
  I64 recv_len;

  // Sender of the most recently delivered datagram (network byte order).
  sockaddr_in recv_addr;
  // Port this socket is bound to, host byte order; 0 means "not bound".
  U16 bound_to;
};
022 
// Lookup table from UDP port (host byte order) to the socket bound to it;
// a NULL entry means nothing is bound. Allocated and zeroed in UdpInit().
// TODO: this takes up half a meg, change it to a binary tree or something
static CUdpSocket **udp_bound_sockets;
025 
// Allocates an outgoing IPv4 packet and writes a UDP header at its start.
//
//   frame_out    receives a pointer to the UDP payload area (just past the
//                header); the caller fills in `length` bytes there
//   source_ip, dest_ip      passed through unchanged to IPv4PacketAlloc
//   source_port, dest_port  host byte order; converted with htons() here
//   length       payload size in bytes, excluding the UDP header
//
// Returns the index produced by IPv4PacketAlloc (to be handed to
// UdpPacketFinish), or a negative value on failure. *frame_out is only
// written on success.
I64 UdpPacketAlloc(U8** frame_out, U32 source_ip, U16 source_port, U32 dest_ip, U16 dest_port, I64 length) {
  I64 udp_length = sizeof(CUdpHeader) + length;

  // The UDP length field is only 16 bits wide: refuse sizes that would be
  // silently truncated by htons() below, as well as negative payload sizes.
  if (length < 0 || udp_length > 0xFFFF)
    return -1;

  U8* frame;
  I64 index = IPv4PacketAlloc(&frame, IP_PROTO_UDP, source_ip, dest_ip, udp_length);

  if (index < 0)
    return index;

  CUdpHeader* hdr = frame;
  hdr->source_port = htons(source_port);
  hdr->dest_port = htons(dest_port);
  hdr->length = htons(udp_length);
  // No checksum is computed; RFC 768 defines an all-zero checksum field
  // as "checksum not generated".
  hdr->checksum = 0;

  *frame_out = frame + sizeof(CUdpHeader);
  return index;
}
042 
// Completes a packet previously obtained from UdpPacketAlloc by forwarding
// its index to the IPv4 layer. Returns whatever IPv4PacketFinish returns.
I64 UdpPacketFinish(I64 index) {
  I64 result = IPv4PacketFinish(index);
  return result;
}
046 
// Extracts the UDP header fields and payload location from an IPv4 packet.
//
//   source_port_out, dest_port_out  receive the ports in host byte order
//   data_out     receives a pointer to the first payload byte (inside packet->data)
//   length_out   receives the payload size in bytes
//   packet       the received IPv4 packet
//
// Returns 0 on success. Returns -1, leaving the outputs untouched, when the
// packet is not UDP, is too short to hold a UDP header, or carries a UDP
// length field that is smaller than the header or larger than the data that
// actually arrived. The checksum is not verified.
I64 UdpParsePacket(U16* source_port_out, U16* dest_port_out, U8** data_out, I64* length_out, CIPv4Packet* packet) {
  if (packet->proto != IP_PROTO_UDP)
    return -1;

  // Must not touch the header unless all of it is really there
  if (packet->length < sizeof(CUdpHeader))
    return -1;

  CUdpHeader* hdr = packet->data;
  //"UDP: from %d, to %d, len %d, chksum %d\n",
  //    ntohs(hdr->source_port), ntohs(hdr->dest_port), ntohs(hdr->length), ntohs(hdr->checksum);

  // The length field counts header + payload. Trust it only as far as the
  // IPv4 payload allows; anything beyond it is padding, not datagram content.
  I64 udp_length = ntohs(hdr->length);

  if (udp_length < sizeof(CUdpHeader) || udp_length > packet->length)
    return -1;

  *source_port_out = ntohs(hdr->source_port);
  *dest_port_out = ntohs(hdr->dest_port);

  *data_out = packet->data + sizeof(CUdpHeader);
  *length_out = udp_length - sizeof(CUdpHeader);

  return 0;
}
067 
// accept() slot for UDP sockets. Not supported: always returns -1.
// The parameters exist only so the signature fits the CSocket accept slot.
I64 UdpSocketAccept(CUdpSocket* s, sockaddr* addr, I64 addrlen) {
  I64 result = -1;

  no_warn addrlen;
  no_warn addr;
  no_warn s;

  return result;
}
074 
// Binds a UDP socket to the port given in `addr` so that UdpHandler delivers
// datagrams for that port to it.
//
//   s        socket to bind; must not already be bound
//   addr     really a sockaddr_in; only sin_port (network byte order) is used
//   addrlen  size of *addr, must be at least sizeof(sockaddr_in)
//
// Returns 0 on success, -1 if the arguments are invalid, the socket is
// already bound, the port is 0, or another socket owns the port.
I64 UdpSocketBind(CUdpSocket* s, sockaddr* addr, I64 addrlen) {
  if (!addr || addrlen < sizeof(sockaddr_in))
    return -1;

  if (s->bound_to)
    return -1;

  sockaddr_in* addr_in = addr;
  U16 port = ntohs(addr_in->sin_port);

  // bound_to == 0 is how an unbound socket is represented, so port 0 cannot
  // be recorded. Registering it anyway would leave a table entry that
  // UdpSocketClose never clears, i.e. a dangling pointer after Free().
  if (port == 0)
    return -1;

  // TODO: address & stuff
  if (udp_bound_sockets[port] != NULL)
    return -1;

  udp_bound_sockets[port] = s;
  s->bound_to = port;
  return 0;
}
093 
// Releases a UDP socket: removes it from the bound-port table if it was
// bound, then frees its memory. Always returns 0. `s` must not be used
// after this call.
I64 UdpSocketClose(CUdpSocket* s) {
  U16 port = s->bound_to;

  // 0 means the socket was never bound, so there is no table entry to clear
  if (port != 0)
    udp_bound_sockets[port] = NULL;

  Free(s);
  return 0;
}
101 
// connect() slot for UDP sockets. Not implemented yet: always returns -1
// and ignores all of its arguments.
I64 UdpSocketConnect(CUdpSocket* s, sockaddr* addr, I64 addrlen) {
  // FIXME: implement
  I64 result = -1;

  no_warn addrlen;
  no_warn addr;
  no_warn s;

  return result;
}
109 
// listen() slot for UDP sockets. Not supported: always returns -1.
// The parameters exist only so the signature fits the CSocket listen slot.
I64 UdpSocketListen(CUdpSocket* s, I64 backlog) {
  I64 result = -1;

  no_warn backlog;
  no_warn s;

  return result;
}
115 
// Waits for one datagram on a UDP socket and copies it into `buf`.
//
//   s         socket to receive on (UdpHandler only delivers to bound sockets)
//   buf, len  destination buffer and its capacity; longer datagrams are truncated
//   flags     ignored
//   src_addr  if not NULL, receives the sender's address as a sockaddr_in
//   addrlen   capacity of *src_addr; at most sizeof(sockaddr_in) bytes are written
//
// The call arms the socket by publishing buf/len in it, then yields until
// UdpHandler clears recv_buf (datagram delivered) or, when a receive timeout
// is configured, until the timeout expires.
//
// Returns the number of bytes stored in `buf`, or -1 on timeout. On timeout
// *src_addr is left untouched.
I64 UdpSocketRecvfrom(CUdpSocket* s, U8* buf, I64 len, I64 flags, sockaddr* src_addr, I64 addrlen) {
  no_warn flags;

  s->recv_buf = buf;
  s->recv_len = len;

  if (s->rcvtimeo_ms != 0)
    s->recv_maxtime = cnts.jiffies + s->rcvtimeo_ms * JIFFY_FREQ / 1000;

  while (s->recv_buf != NULL) {
    // Check for timeout
    if (s->rcvtimeo_ms != 0 && cnts.jiffies > s->recv_maxtime) {
      // TODO: seterror(EWOULDBLOCK)
      // Disarm the receive: once we return, `buf` belongs to the caller
      // again and UdpHandler must not write a late datagram into it.
      s->recv_buf = NULL;
      s->recv_len = -1;
      break;
    }

    Yield;
  }

  I64 result = s->recv_len;

  // Report the sender only when something was actually received, and never
  // copy more than the stored address holds, whatever addrlen claims.
  if (src_addr && result >= 0) {
    I64 copy_len = addrlen;

    if (copy_len > sizeof(sockaddr_in))
      copy_len = sizeof(sockaddr_in);

    if (copy_len > 0)
      MemCpy(src_addr, &s->recv_addr, copy_len);
  }

  return result;
}
144 
// Sends one datagram containing `len` bytes from `buf` to `dest_addr`.
//
//   s          socket to send from; its bound port (0 if unbound) becomes the
//              datagram's source port so that replies can reach the socket
//   buf, len   payload
//   flags      ignored
//   dest_addr  destination address and port, network byte order
//   addrlen    size of *dest_addr, must be at least sizeof(sockaddr_in)
//
// Returns -1 on invalid arguments or if no packet could be allocated;
// otherwise returns the result of UdpPacketFinish.
I64 UdpSocketSendto(CSocket* s, U8* buf, I64 len, I64 flags, sockaddr_in* dest_addr, I64 addrlen) {
  no_warn flags;

  if (!dest_addr || addrlen < sizeof(sockaddr_in))
    return -1;

  if (len < 0)
    return -1;

  // This function is only installed as the sendto slot of sockets created
  // by UdpSocket(), where the CSocket is the first member of a CUdpSocket.
  // NOTE(review): a direct caller passing a bare CSocket would break this.
  CUdpSocket* udp = s;

  U8* frame;

  I64 index = UdpPacketAlloc(&frame, IPv4GetAddress(), udp->bound_to, ntohl(dest_addr->sin_addr.s_addr),
      ntohs(dest_addr->sin_port), len);

  if (index < 0)
    return -1;

  MemCpy(frame, buf, len);
  return UdpPacketFinish(index);
}
163 
// Sets a socket option. The only option understood is
// SOL_SOCKET / SO_RCVTIMEO_MS, whose value is an 8-byte integer giving the
// receive timeout in milliseconds (0 disables the timeout).
//
// Returns 0 if the option was applied, -1 for any other level, option or
// value size.
I64 UdpSocketSetsockopt(CUdpSocket* s, I64 level, I64 optname, U8* optval, I64 optlen) {
  if (level != SOL_SOCKET || optname != SO_RCVTIMEO_MS || optlen != 8)
    return -1;

  I64* timeout_ms = optval;
  s->rcvtimeo_ms = *timeout_ms;
  return 0;
}
172 
// Creates a new, unbound UDP socket.
//
//   domain  must be AF_INET
//   type    must be SOCK_DGRAM
//
// Returns the new socket, or NULL if domain/type do not describe a UDP
// socket. The socket is released with its close slot (UdpSocketClose).
CUdpSocket* UdpSocket(U16 domain, U16 type) {
  if (domain != AF_INET || type != SOCK_DGRAM)
    return NULL;

  CUdpSocket* s = MAlloc(sizeof(CUdpSocket));

  // Start from all-zero state so that nothing is left uninitialised: no
  // timeout, no pending receive, not bound, an empty sender address, and
  // NULL in any CSocket slot that is not assigned below.
  MemSet(s, 0, sizeof(CUdpSocket));

  s->sock.accept =      &UdpSocketAccept;
  s->sock.bind =        &UdpSocketBind;
  s->sock.close =       &UdpSocketClose;
  s->sock.connect =     &UdpSocketConnect;
  s->sock.listen =      &UdpSocketListen;
  s->sock.recvfrom =    &UdpSocketRecvfrom;
  s->sock.sendto =      &UdpSocketSendto;
  s->sock.setsockopt =  &UdpSocketSetsockopt;

  s->recv_addr.sin_family = AF_INET;
  return s;
}
196 
// Layer-4 handler for incoming UDP packets (registered for IP_PROTO_UDP).
//
// Parses the packet and, if a socket is bound to the destination port and
// has a receive pending, copies the payload (truncated to the receiver's
// buffer size) into that buffer, records the sender's address and signals
// completion by clearing the socket's recv_buf. Datagrams for which no
// socket is waiting are dropped silently.
//
// Returns the (negative) error from UdpParsePacket, or 0 otherwise.
I64 UdpHandler(CIPv4Packet* packet) {
  U16 source_port;
  U16 dest_port;
  U8* data;
  I64 length;

  I64 error = UdpParsePacket(&source_port, &dest_port, &data, &length, packet);

  if (error < 0)
    return error;

  //"%u => %p\n", dest_port, udp_bound_sockets[dest_port];

  CUdpSocket* s = udp_bound_sockets[dest_port];

  // FIXME: should also check that bound address is INADDR_ANY,
  //        OR packet dest IP matches bound address
  if (s == NULL || s->recv_buf == NULL)
    return error;

  // Copy no more than the receiver asked for and no more than we have;
  // never hand a negative count to MemCpy.
  I64 num_recv = s->recv_len;

  if (num_recv > length)
    num_recv = length;

  if (num_recv < 0)
    num_recv = 0;

  if (num_recv > 0)
    MemCpy(s->recv_buf, data, num_recv);

  // TODO: avoid the repeated network->host->network conversions
  s->recv_addr.sin_port = htons(source_port);
  s->recv_addr.sin_addr.s_addr = htonl(packet->source_ip);
  s->recv_len = num_recv;

  // Signal that we received something. This must come last: the waiting
  // recvfrom treats recv_buf == NULL as "length and address are ready".
  s->recv_buf = NULL;

  return error;
}
235 
// Allocates the port -> socket table with one slot for each of the 65536
// possible UDP ports and marks every port as unbound (NULL).
U0 UdpInit() {
  I64 table_bytes = 65536 * sizeof(CUdpSocket*);

  udp_bound_sockets = MAlloc(table_bytes);
  MemSet(udp_bound_sockets, 0, table_bytes);
}
240 
// Executed when this file is loaded: set up the port table first, then
// hook UDP into the IPv4 layer and the socket factory, both of which rely
// on that table being ready.
UdpInit();
RegisterL4Protocol(IP_PROTO_UDP, &UdpHandler);
RegisterSocketClass(AF_INET, SOCK_DGRAM, &UdpSocket);