001 // vim: set ft=c:
002 
// Response codes (RCODE, low 4 bits of the header flags word), numbered as in
// RFC 1035 section 4.1.1. Not referenced by the code below.
#define DNS_RCODE_NO_ERROR        0
#define DNS_RCODE_FORMAT_ERROR    1
#define DNS_RCODE_SERVER_FAILURE  2
#define DNS_RCODE_NAME_ERROR      3
#define DNS_RCODE_NOT_IMPLEMENTED 4
#define DNS_RCODE_REFUSED         5

// Header flag bits, in host byte order; DnsSendQuestion assembles the flags
// word from these and converts it with htons before writing the header.
#define DNS_FLAG_RA         0x0080
#define DNS_FLAG_RD         0x0100
#define DNS_FLAG_TC         0x0200
#define DNS_FLAG_AA         0x0400

// Opcodes; DnsSendQuestion shifts the opcode left by 11 into the flags word.
#define DNS_OP_QUERY        0
#define DNS_OP_IQUERY       1
#define DNS_OP_STATUS       2

// Query/response bit of the flags word (set in responses).
#define DNS_FLAG_QR         0x8000

// http://www.freesoft.org/CIE/RFC/1035/14.htm
#define DNS_TYPE_A          1
#define DNS_TYPE_NS         2
#define DNS_TYPE_CNAME      5
#define DNS_TYPE_PTR        12
#define DNS_TYPE_MX         15
#define DNS_TYPE_TXT        16

// http://www.freesoft.org/CIE/RFC/1035/16.htm
#define DNS_CLASS_IN        1

// Receive timeout for each attempt, in milliseconds (set via SO_RCVTIMEO_MS).
#define DNS_TIMEOUT         5000
// Number of send/receive attempts DnsRunQuery makes before giving up.
#define DNS_MAX_RETRIES     3
034 
// One resolved host name in the dns_cache linked list.
class CDnsCacheEntry {
  CDnsCacheEntry* next;   // next (older) entry; DnsCachePut prepends
  U8* hostname;           // owned copy made with StrNew in DnsCachePut
  addrinfo info;          // copy of the resolved result (AddrInfoCopy)
  // TODO: honor TTL
};
041 
// Fixed DNS message header. DnsSendQuestion overlays this class directly on
// the outgoing frame and DnsParseResponse on the received buffer, so the
// field order and sizes must match the on-wire layout exactly.
// All fields hold network-byte-order values (written with htons, read with ntohs).
class CDnsHeader {
  U16 id;       // transaction id, matched against the reply in DnsParseResponse
  U16 flags;    // QR/opcode/AA/TC/RD/RA/RCODE bits (see DNS_FLAG_*)
  U16 qdcount;  // number of question entries
  U16 ancount;  // number of answer records
  U16 nscount;  // number of authority records
  U16 arcount;  // number of additional records
};
050 
// A domain name split into labels in left-to-right order
// (e.g. "www", "example", "com").
// labels[0] is always the base of the single allocation that holds every
// label string; DnsFreeQuestion/DnsFreeRR release the text through labels[0].
class CDnsDomainName {
  U8** labels;       // array of NUL-terminated label strings
  I64 num_labels;    // number of valid entries in labels
};
055 
// One entry of the question section.
// NOTE(review): qtype/qclass are in host order when filled by DnsBuildQuestion
// (serialized big-endian by DnsSerializeQuestion), but hold raw wire
// (network-order) values when filled by DnsParseQuestion -- apply ntohs there.
class CDnsQuestion {
  CDnsQuestion* next;   // link used by the question chain in DnsParseResponse

  CDnsDomainName qname;
  U16 qtype;
  U16 qclass;
};
063 
// One resource record parsed from a response.
// type, class_, ttl and rdlength hold raw network-order values (callers use
// htons/ntohs on them). They are declared back to back, and the original
// DnsParseRR fills them with a single 10-byte MemCpy, so keep this order.
class CDnsRR {
  CDnsRR* next;         // link used by the answer chain in DnsParseResponse

  CDnsDomainName name;
  U16 type;
  U16 class_;
  U32 ttl;
  U16 rdlength;         // length of rdata in bytes (network order)
  U8* rdata;            // points into the packet buffer given to DnsParseRR; not owned
};
074 
// Resolved-address cache: singly linked list, newest entry first
// (DnsCachePut prepends). Nothing in this file removes or expires entries.
// TODO: use a Hash table
static CDnsCacheEntry* dns_cache = NULL;

// Resolver address set by DnsSetResolverIPv4; 0 means "not configured" and
// makes DnsSendQuestion return -1. NOTE(review): byte order is whatever
// UdpPacketAlloc expects for its destination address -- confirm.
static U32 dns_ip = 0;
079 
// Look hostname up in the cache by exact string comparison.
// Returns the first matching entry, or NULL if the name is not cached.
static CDnsCacheEntry* DnsCacheFind(U8* hostname) {
  CDnsCacheEntry* entry;

  for (entry = dns_cache; entry; entry = entry->next) {
    if (!StrCmp(entry->hostname, hostname))
      break;
  }

  // Either the match, or NULL after walking off the end of the list.
  return entry;
}
092 
// Insert a copy of info for hostname at the head of the cache.
// If hostname is already cached, the existing entry is returned unchanged
// (info is NOT copied over it). Returns the cache entry for hostname.
static CDnsCacheEntry* DnsCachePut(U8* hostname, addrinfo* info) {
  CDnsCacheEntry* entry = DnsCacheFind(hostname);

  if (entry)
    return entry;

  entry = MAlloc(sizeof(CDnsCacheEntry));
  entry->next = dns_cache;
  entry->hostname = StrNew(hostname);
  AddrInfoCopy(&entry->info, info);

  // Publish the fully initialized entry as the new list head.
  dns_cache = entry;
  return entry;
}
107 
// Number of bytes DnsSerializeQuestion writes for question:
// a length byte plus the text of every label, the terminating zero-length
// label, then 2 bytes of QTYPE and 2 bytes of QCLASS.
static I64 DnsCalcQuestionSize(CDnsQuestion* question) {
  I64 total = 1 + 2 + 2;
  I64 n;

  for (n = 0; n < question->qname.num_labels; n++)
    total += StrLen(question->qname.labels[n]) + 1;

  return total;
}
116 
// Write question to buf in wire format: length-prefixed labels, a zero
// terminator, then qtype and qclass big-endian. buf must have room for
// DnsCalcQuestionSize(question) bytes.
static U0 DnsSerializeQuestion(U8* buf, CDnsQuestion* question) {
  U8* out = buf;
  U8* src;
  I64 n, len;

  for (n = 0; n < question->qname.num_labels; n++) {
    src = question->qname.labels[n];
    len = StrLen(src);

    *(out++) = len;
    MemCpy(out, src, len);
    out += len;
  }

  // Root label terminates the name.
  *(out++) = 0;

  // qtype/qclass are taken as host-order values and emitted big-endian.
  *(out++) = question->qtype >> 8;
  *(out++) = question->qtype & 0xff;
  *(out++) = question->qclass >> 8;
  *(out++) = question->qclass & 0xff;
}
134 
// Build and send one DNS query carrying question to UDP port 53 at dns_ip,
// from local_port, with transaction id `id`.
// Returns -1 if no resolver is configured, the negative UdpPacketAlloc result
// if the packet cannot be allocated, otherwise UdpPacketFinish's result.
static I64 DnsSendQuestion(U16 id, U16 local_port, CDnsQuestion* question) {
  U8* frame;
  CDnsHeader* hdr;
  I64 index;
  U16 request_flags;

  if (!dns_ip)
    return -1;

  index = UdpPacketAlloc(&frame, IPv4GetAddress(), local_port, dns_ip, 53,
      sizeof(CDnsHeader) + DnsCalcQuestionSize(question));
  if (index < 0)
    return index;

  // Standard query with recursion desired.
  request_flags = (DNS_OP_QUERY << 11) | DNS_FLAG_RD;

  // NOTE(review): assumes frame points at the UDP payload, so the DNS header
  // is written in place at its start.
  hdr = frame;
  hdr->id      = htons(id);
  hdr->flags   = htons(request_flags);
  hdr->qdcount = htons(1);
  hdr->ancount = 0;
  hdr->nscount = 0;
  hdr->arcount = 0;

  // The question section follows the fixed-size header directly.
  DnsSerializeQuestion(frame + sizeof(CDnsHeader), question);

  return UdpPacketFinish(index);
}
160 
// Parse a (possibly compressed) domain name starting at *data_inout.
//
// packet_data/packet_length - whole received message; compression pointers
//                             are offsets from its start.
// data_inout/length_inout   - cursor into the message; on success advanced
//                             past the name (past the first pointer, if any).
// name_out                  - receives the labels. labels[0] is the base of
//                             one 256-byte buffer holding all label text.
//
// Returns 0 on success, -1 on malformed or truncated input. On failure all
// memory allocated here is released and name_out->labels is set to NULL.
// Input comes from the network, so every read, write and jump is bounded.
static I64 DnsParseDomainName(U8* packet_data, I64 packet_length,
    U8** data_inout, I64* length_inout, CDnsDomainName* name_out) {
  U8* data = *data_inout;
  I64 length = *length_inout;
  Bool jump_taken = FALSE;
  I64 jumps = 0;
  I64 label_len;
  I64 offset;
  U8* name_base;
  U8* name_buf;

  if (length < 1) {
    //"DnsParseDomainName: EOF\n";
    return -1;
  }

  // Label text is capped at 256 bytes (checked below). Every label uses at
  // least 2 bytes (one char + NUL), so 128 label slots can never overflow.
  name_out->labels = MAlloc(128 * sizeof(U8*));
  name_out->num_labels = 0;

  name_base = MAlloc(256);
  name_buf = name_base;
  // Keep the allocation base in labels[0] even for the root name (no labels)
  // so the free routines can always release it.
  name_out->labels[0] = name_base;

  while (TRUE) {
    // A name must end with a zero-length label (or continue via a pointer).
    if (length < 1)
      goto fail;

    label_len = *(data++);
    length--;

    if (label_len == 0)
      break;

    if (label_len >= 192) {
      // Compression pointer: 14-bit offset from the start of the message.
      // Needs its second byte; cap the number of jumps to defeat pointer loops.
      if (length < 1 || ++jumps > 32)
        goto fail;

      offset = ((label_len & 0x3f) << 8) | *data;
      if (offset >= packet_length)
        goto fail;

      if (!jump_taken) {
        // The caller's cursor resumes right after the first pointer.
        *data_inout = data + 1;
        *length_inout = length - 1;
        jump_taken = TRUE;
      }

      //"jmp %d\n", offset;

      data = packet_data + offset;
      length = packet_length - offset;
    }
    else if (label_len >= 64) {
      // 0x40/0x80 prefixes are reserved/extended label types: not supported.
      goto fail;
    }
    else {
      if (length < label_len)
        goto fail;
      if (name_buf - name_base + label_len + 1 > 256)
        goto fail;

      MemCpy(name_buf, data, label_len);
      data += label_len;
      length -= label_len;

      name_buf[label_len] = 0;
      //"%d bytes => %s\n", label_len, name_buf;
      name_out->labels[name_out->num_labels++] = name_buf;

      name_buf += label_len + 1;
    }
  }

  if (!jump_taken) {
    *data_inout = data;
    *length_inout = length;
  }

  return 0;

fail:
  Free(name_base);
  Free(name_out->labels);
  name_out->labels = NULL;
  name_out->num_labels = 0;
  return -1;
}
221 
// Parse one question-section entry (QNAME, QTYPE, QCLASS) at *data_inout.
// On success, advances *data_inout/*length_inout past the entry and returns 0.
// qtype/qclass are stored as raw wire values (network byte order); use ntohs()
// to interpret them. Returns -1 on malformed input. If the name parses but
// the 4 fixed bytes are missing, the parsed name is released before returning.
static I64 DnsParseQuestion(U8* packet_data, I64 packet_length,
    U8** data_inout, I64* length_inout, CDnsQuestion* question_out) {
  I64 error = DnsParseDomainName(packet_data, packet_length,
      data_inout, length_inout, &question_out->qname);

  if (error < 0)
    return error;

  U8* data = *data_inout;
  I64 length = *length_inout;

  if (length < 4) {
    // The caller drops this question on error, so nobody else would free
    // the name allocated above.
    Free(question_out->qname.labels[0]);
    Free(question_out->qname.labels);
    question_out->qname.labels = NULL;
    question_out->qname.num_labels = 0;
    return -1;
  }

  question_out->next = NULL;
  question_out->qtype = (data[1] << 8) | data[0];
  question_out->qclass = (data[3] << 8) | data[2];

  //"DnsParseQuestion: qtype %d, qclass %d\n", ntohs(question_out->qtype), ntohs(question_out->qclass);

  *data_inout = data + 4;
  *length_inout = length - 4;
  return 0;
}
246 
// Parse one resource record (NAME, TYPE, CLASS, TTL, RDLENGTH, RDATA) at
// *data_inout. On success, advances the cursor past the record and returns 0.
// type/class_/ttl/rdlength are stored as raw wire values (network byte order).
// rr_out->rdata points into the packet buffer and is not copied.
// Returns -1 on malformed input. If the name parses but the rest of the
// record is truncated, the parsed name is released before returning.
static I64 DnsParseRR(U8* packet_data, I64 packet_length,
    U8** data_inout, I64* length_inout, CDnsRR* rr_out) {
  U8* data;
  I64 length;
  I64 record_length;
  I64 error = DnsParseDomainName(packet_data, packet_length,
      data_inout, length_inout, &rr_out->name);

  if (error < 0)
    return error;

  data = *data_inout;
  length = *length_inout;

  // Fixed part: TYPE(2) CLASS(2) TTL(4) RDLENGTH(2)
  if (length < 10)
    goto fail;

  rr_out->next = NULL;
  // Assemble each field from its wire bytes, low address in the low byte:
  // this keeps the network-order representation without relying on the
  // class's field layout.
  rr_out->type = (data[1] << 8) | data[0];
  rr_out->class_ = (data[3] << 8) | data[2];
  rr_out->ttl = (data[7] << 24) | (data[6] << 16) | (data[5] << 8) | data[4];
  rr_out->rdlength = (data[9] << 8) | data[8];

  record_length = 10 + ntohs(rr_out->rdlength);

  if (length < record_length)
    goto fail;

  rr_out->rdata = data + 10;

  //"DnsParseRR: type %d, class %d\n, ttl %d, rdlength %d\n",
  //    ntohs(rr_out->type), ntohs(rr_out->class_), ntohl(rr_out->ttl), ntohs(rr_out->rdlength);

  *data_inout = data + record_length;
  *length_inout = length - record_length;
  return 0;

fail:
  // The caller drops this record on error; release the name parsed above.
  Free(rr_out->name.labels[0]);
  Free(rr_out->name.labels);
  rr_out->name.labels = NULL;
  rr_out->name.num_labels = 0;
  return -1;
}
278 
// Parse a DNS response of `length` bytes at `data`.
//
// id            - expected transaction id; 0 accepts any id.
// hdr_out       - on success, points at the header inside `data` (not copied).
// questions_out - parsed questions are prepended (so end up in reverse order).
// answers_out   - parsed answer records are prepended likewise.
//
// Returns 0 on success, -1 on failure. Even on failure, nodes already linked
// into *questions_out / *answers_out belong to the caller and must be freed.
// Authority and additional sections are not parsed.
static I64 DnsParseResponse(U16 id, U8* data, I64 length,
    CDnsHeader** hdr_out, CDnsQuestion** questions_out,
    CDnsRR** answers_out) {
  U8* packet_data = data;
  I64 packet_length = length;

  if (length < sizeof(CDnsHeader)) {
    //"DnsParseResponse: too short\n";
    return -1;
  }

  CDnsHeader* hdr = data;
  // Advance the cursor AND shrink the remaining length; otherwise the
  // section parsers would believe 12 more bytes are available than exist.
  data += sizeof(CDnsHeader);
  length -= sizeof(CDnsHeader);

  if (id != 0 && ntohs(hdr->id) != id) {
    //"DnsParseResponse: id %04Xh != %04Xh\n", ntohs(hdr->id), id;
    return -1;
  }

  I64 i;

  for (i = 0; i < ntohs(hdr->qdcount); i++) {
    CDnsQuestion* question = MAlloc(sizeof(CDnsQuestion));
    if (DnsParseQuestion(packet_data, packet_length, &data, &length, question) < 0) {
      // This node was never linked into the chain; free it here.
      Free(question);
      return -1;
    }

    question->next = *questions_out;
    *questions_out = question;
  }

  for (i = 0; i < ntohs(hdr->ancount); i++) {
    CDnsRR* answer = MAlloc(sizeof(CDnsRR));
    if (DnsParseRR(packet_data, packet_length, &data, &length, answer) < 0) {
      // This node was never linked into the chain; free it here.
      Free(answer);
      return -1;
    }

    answer->next = *answers_out;
    *answers_out = answer;
  }

  *hdr_out = hdr;
  return 0;
}
321 
// Initialize question as an IN/A query for the dotted name `name`.
// The name is copied with StrNew and split in place at each '.'.
// qtype/qclass are stored in host order (DnsSerializeQuestion converts them).
// Release with DnsFreeQuestion.
static U0 DnsBuildQuestion(CDnsQuestion* question, U8* name) {
  U8* copy = StrNew(name);
  U8* p;
  I64 capacity = 1;

  // A name with k dots has at most k + 1 labels: size the array to fit so
  // long names (e.g. reverse-lookup names) cannot overflow it.
  for (p = copy; *p; p++) {
    if (*p == '.')
      capacity++;
  }

  question->next = NULL;
  question->qname.labels = MAlloc(capacity * sizeof(U8*));
  // labels[0] owns the StrNew copy, even for an empty name, so that
  // DnsFreeQuestion always releases it.
  question->qname.labels[0] = copy;
  question->qname.num_labels = 0;
  question->qtype = DNS_TYPE_A;
  question->qclass = DNS_CLASS_IN;

  while (*copy) {
    question->qname.labels[question->qname.num_labels++] = copy;
    U8* dot = StrFirstOcc(copy, ".");

    if (!dot)
      break;

    *dot = 0;
    copy = dot + 1;
  }
}
344 
// Release the name storage owned by question (not the CDnsQuestion itself):
// the label text, whose base is labels[0], and the labels array.
// Safe to call when labels is NULL.
static U0 DnsFreeQuestion(CDnsQuestion* question) {
  if (question->qname.labels) {
    Free(question->qname.labels[0]);
    Free(question->qname.labels);
    question->qname.labels = NULL;
  }
  question->qname.num_labels = 0;
}
348 
// Release the name storage owned by rr (not the CDnsRR itself, and not
// rdata, which points into the packet buffer). Safe when labels is NULL.
static U0 DnsFreeRR(CDnsRR* rr) {
  if (rr->name.labels) {
    Free(rr->name.labels[0]);
    Free(rr->name.labels);
    rr->name.labels = NULL;
  }
  rr->name.num_labels = 0;
}
352 
// Free every question in a `next`-linked chain: each node's name storage
// (via DnsFreeQuestion) and then the node itself.
static U0 DnsFreeQuestionChain(CDnsQuestion* questions) {
  CDnsQuestion* victim;

  while (questions) {
    victim = questions;
    // Step forward before the current node is released.
    questions = questions->next;

    DnsFreeQuestion(victim);
    Free(victim);
  }
}
361 
// Free every resource record in a `next`-linked chain: each record's name
// storage (via DnsFreeRR) and then the node itself.
static U0 DnsFreeRRChain(CDnsRR* rrs) {
  while (rrs) {
    CDnsRR* next = rrs->next;   // was mistyped as CDnsQuestion*
    DnsFreeRR(rrs);
    Free(rrs);
    rrs = next;
  }
}
370 
// Resolve `name` to an IPv4 address by querying the configured resolver.
//
// sock    - UDP socket; it is given a DNS_TIMEOUT ms receive timeout and bound
//           to a random local port.
// name    - dotted host name; an IN/A question is sent for it.
// port    - port placed in the result's sockaddr_in, converted with htons.
//           NOTE(review): assumes host byte order; the only visible caller
//           passes 0.
// res_out - on success receives a newly allocated single-entry addrinfo.
//
// The question is sent up to DNS_MAX_RETRIES times. The first answer record
// that is class IN, type A with 4 bytes of rdata is used and also stored in
// the cache. Returns 0 on success, negative on failure.
static I64 DnsRunQuery(I64 sock, U8* name, U16 port, addrinfo** res_out) {
  I64 retries = 0;
  I64 timeout = DNS_TIMEOUT;

  if (setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO_MS, &timeout, sizeof(timeout)) < 0) {
    "DnsRunQuery: setsockopt failed\n";
  }

  U16 local_port = RandU16();

  sockaddr_in addr;
  addr.sin_family = AF_INET;
  addr.sin_port = htons(local_port);
  addr.sin_addr.s_addr = INADDR_ANY;

  if (bind(sock, &addr, sizeof(addr)) < 0) {
    "DnsRunQuery: failed to bind\n";
    return -1;
  }

  U8 buffer[2048];

  I64 count;
  sockaddr_in addr_in;

  U16 id = RandU16();
  I64 error = 0;

  CDnsQuestion question;
  DnsBuildQuestion(&question, name);

  while (1) {
    error = DnsSendQuestion(id, local_port, &question);
    if (error < 0)
      break;  // leave through the common cleanup so `question` is freed

    count = recvfrom(sock, buffer, sizeof(buffer), 0, &addr_in, sizeof(addr_in));

    if (count > 0) {
      //"Try parse response\n";
      CDnsHeader* hdr = NULL;
      CDnsQuestion* questions = NULL;
      CDnsRR* answers = NULL;

      error = DnsParseResponse(id, buffer, count, &hdr, &questions, &answers);

      if (error >= 0) {
        Bool have = FALSE;

        // Look for a suitable A-record in the answer
        CDnsRR* answer = answers;
        while (answer) {
          // TODO: if there are multiple acceptable answers,
          //       we should pick one at random -- not just the first one
          if (htons(answer->type) == DNS_TYPE_A
              && htons(answer->class_) == DNS_CLASS_IN
              && htons(answer->rdlength) == 4) {
            addrinfo* res = MAlloc(sizeof(addrinfo));
            res->ai_flags = 0;
            res->ai_family = AF_INET;
            res->ai_socktype = 0;
            res->ai_protocol = 0;
            res->ai_addrlen = sizeof(sockaddr_in);
            res->ai_addr = MAlloc(sizeof(sockaddr_in));
            res->ai_canonname = NULL;
            res->ai_next = NULL;

            sockaddr_in* sa = res->ai_addr;
            sa->sin_family = AF_INET;
            sa->sin_port = htons(port);
            // Take the address from the record that matched, not from the
            // head of the list (which may be e.g. a CNAME).
            // rdata points into `buffer`, which is still valid here.
            MemCpy(&sa->sin_addr.s_addr, answer->rdata, 4);

            DnsCachePut(name, res);
            *res_out = res;
            have = TRUE;
            break;
          }

          answer = answer->next;
        }

        DnsFreeQuestionChain(questions);
        DnsFreeRRChain(answers);

        if (have)
          break;

        // At this point we could try iterative resolution,
        // but all end-user DNS servers would have tried that already

        "DnsParseResponse: no suitable answer in reply\n";
        error = -1;
      }
      else {
        // A failed parse may still have linked some nodes; free them too.
        DnsFreeQuestionChain(questions);
        DnsFreeRRChain(answers);
        "DnsParseResponse: error %d\n", error;
      }
    }

    if (++retries == DNS_MAX_RETRIES) {
      "DnsRunQuery: max retries reached\n";
      error = -1;
      break;
    }
  }

  DnsFreeQuestion(&question);
  return error;
}
478 
// getaddrinfo backend registered by DnsInit.
// Serves `node` from the cache when possible (the copy gets AI_CACHED set in
// ai_flags); otherwise runs a DNS query on a fresh UDP socket.
// `service` and `hints` are currently ignored.
// Returns 0 and sets *res on success, negative on failure.
I64 DnsGetaddrinfo(U8* node, U8* service, addrinfo* hints, addrinfo** res) {
  no_warn service;
  no_warn hints;

  // NULL node is legal for POSIX getaddrinfo, but both the cache lookup
  // (StrCmp) and the query builder dereference it: reject it up front.
  if (!node)
    return -1;

  CDnsCacheEntry* cached = DnsCacheFind(node);

  if (cached) {
    *res = MAlloc(sizeof(addrinfo));
    AddrInfoCopy(*res, &cached->info);
    (*res)->ai_flags |= AI_CACHED;
    return 0;
  }

  I64 sock = socket(AF_INET, SOCK_DGRAM);
  I64 error = 0;

  if (sock >= 0) {
    // TODO: service should be parsed as int, specifying port number
    error = DnsRunQuery(sock, node, 0, res);

    close(sock);
  }
  else
    error = -1;

  return error;
}
506 
// Set the resolver address used by DnsSendQuestion. While it is 0, uncached
// lookups fail (DnsSendQuestion returns -1).
// NOTE(review): byte order of ip is whatever UdpPacketAlloc expects -- confirm.
U0 DnsSetResolverIPv4(U32 ip) {
  dns_ip = ip;
}
510 
// Resolve hostname with getaddrinfo and print every returned entry.
public U0 Host(U8* hostname) {
  addrinfo* res = NULL;
  I64 error = getaddrinfo(hostname, NULL, NULL, &res);

  if (error < 0) {
    "getaddrinfo: error %d\n", error;
    // Nothing was returned, so there is nothing to free; do not pass NULL
    // to freeaddrinfo.
    return;
  }

  addrinfo* curr;
  for (curr = res; curr; curr = curr->ai_next) {
    "flags %04Xh, family %d, socktype %d, proto %d, addrlen %d, addr %s\n",
        curr->ai_flags, curr->ai_family, curr->ai_socktype, curr->ai_protocol, curr->ai_addrlen,
        inet_ntoa((curr->ai_addr(sockaddr_in*))->sin_addr);
  }

  freeaddrinfo(res);
}
530 
// Install DnsGetaddrinfo as the socket layer's address resolver.
U0 DnsInit() {
  // `static` keeps the resolver object alive after DnsInit returns, since
  // its address is stored in socket_addr_resolver.
  static CAddrResolver dns_addr_resolver;
  dns_addr_resolver.getaddrinfo = &DnsGetaddrinfo;

  socket_addr_resolver = &dns_addr_resolver;
}

// File-scope statement: calls DnsInit once when this file is compiled/loaded
// (HolyC invokes a parameterless function named as a statement).
DnsInit;