(* dns_server/main.ml *)
open Unix
open Netmask_table

(* One configured answer for a domain. A client is served [ip_int32] when its
   own address and [ip_int32] agree on the bits selected by [netmask]. *)
type target_ip = {
  name : string;       (* domain name this entry was configured for *)
  netmask : Int32.t;   (* mask derived from the configured CIDR prefix *)
  ip_int32 : Int32.t;  (* answer address, numeric form *)
  ip_str : string;     (* answer address, as written in the configuration *)
}

(* Raised by the SIGUSR1 handler; the request loop polls it, clears it and
   reloads the configuration. *)
let reload_config = (ref false : bool ref)

(* let debug_resp = "#\xd3\x81\x80\x00\x01\x00\x01\x00\x00\x00\x00\x02vk\x03com\x00\x00\x01\x00\x01\xc0\x0c\x00\x01\x00\x01\x00\x00\x00<\x00\x04\x05\x08\x06\x08" |> Bytes.of_string *)

(* [log msg] prints [msg] on stdout, prefixed with the current local time as
   "YYYY-MM-DD HH:MM:SS| ". [print_endline] flushes stdout after each line.
   Only the fields actually printed are read from the [tm] record, so no
   unused bindings are introduced (and [Stdlib.min] is not shadowed). *)
let log msg =
  let tm = localtime (time ()) in
  Printf.sprintf "%04d-%02d-%02d %02d:%02d:%02d| %s"
    (tm.tm_year + 1900) (tm.tm_mon + 1) tm.tm_mday
    tm.tm_hour tm.tm_min tm.tm_sec msg
  |> print_endline

(* Read the configuration and build the domain lookup table.

   The file is the single command-line argument when exactly one is given,
   otherwise "/etc/dns.conf".

   Returns [Some table], mapping each domain to the list of its [target_ip]
   entries, or [None] (after logging) when the file cannot be read or
   contains at least one invalid line. *)
let load_configuration () =
  (* Group valid lines by domain. Each domain keeps a single binding
     ([Hashtbl.replace], not [add], so no stale shadowed bindings pile up);
     entries are prepended, so a domain's list is in reverse file order. *)
  let make_associative_map conf =
    let hmap = Hashtbl.create (List.length conf) in
    let fill_map (domain, domain_result) =
      let previous =
        Option.value (Hashtbl.find_opt hmap domain) ~default:[]
      in
      Hashtbl.replace hmap domain (domain_result :: previous)
    in
    conf
    |> List.filter_map (fun (c : Configuration.lineresult) ->
           match c with
           | No _ -> assert false (* error lines are rejected before this runs *)
           | Skip -> None
           | Ok (domain, ip, cidr) ->
             let domain_result =
               { name = domain;
                 netmask = Netmask_table.nmask_of_cidr cidr;
                 ip_int32 = int_of_ip ip;
                 ip_str = ip }
             in
             Some (domain, domain_result))
    |> List.iter fill_map;
    hmap
  in
  let conf_file =
    if Array.length Sys.argv = 2 then Sys.argv.(1) else "/etc/dns.conf"
  in
  let is_error (c : Configuration.lineresult) =
    match c with No _ -> true | _ -> false
  in
  (* A missing/unreadable file must not crash the server, especially on a
     SIGUSR1 reload where the previous configuration can be kept. *)
  match Configuration.parse_config_file conf_file with
  | exception Sys_error msg ->
    log ("Cannot read configuration file: " ^ msg);
    None
  | conf when List.exists is_error conf ->
    log "Error in configuration file"; (* TODO: could display lines *)
    None
  | conf ->
    log "Reloaded configuration";
    Some (make_associative_map conf)
   

(* [domain_of msg_len msg] decodes the name of the first question of the DNS
   message held in the first [msg_len] bytes of [msg], e.g. "vk.com".

   The name starts right after the 12-byte header and is a sequence of
   length-prefixed labels ended by a zero byte. Every read is bounded by
   [min msg_len (Bytes.length msg)], so trailing bytes of a larger receive
   buffer are never interpreted.

   Raises [Invalid_argument] when the name is truncated, overruns the
   message, or uses a label length of 64 or more (which covers compression
   pointers, not expected in a question). *)
let domain_of msg_len msg =
  let limit = min msg_len (Bytes.length msg) in
  let rec labels accum pos =
    if pos >= limit then invalid_arg "domain_of: truncated question name"
    else
      let len = Char.code (Bytes.get msg pos) in
      if len = 0 then String.concat "." (List.rev accum)
      else if len land 0xc0 <> 0 then
        invalid_arg "domain_of: compressed or oversized label"
      else if pos + 1 + len > limit then
        invalid_arg "domain_of: label overruns the message"
      else
        labels (Bytes.sub_string msg (pos + 1) len :: accum) (pos + 1 + len)
  in
  labels [] 12

(* The 16-bit transaction ID: the first two bytes of a DNS message, big-endian. *)
let id_of (msg : bytes) : int =
  let id_offset = 0 in
  Bytes.get_uint16_be msg id_offset

(* Build a header-only reply to [req]: the request's ID, flags 0x8183
   (QR=1, RD=1, RA=1, RCODE=3 "name error"), and all four section counts 0.
   Returns the reply buffer together with its length (12). *)
let build_error_reply req =
  let header_len = 12 in
  let reply = Bytes.make header_len '\x00' in
  Bytes.set_uint16_be reply 0 (Bytes.get_uint16_be req 0);
  Bytes.set_uint16_be reply 2 0x8183;
  (reply, header_len)

(* [build_reply src_sock req len_req target_ip] builds a positive answer to
   the query held in the first [len_req] bytes of [req].

   The reply carries the query's ID and its first question, followed by one
   answer record:
   - a compression pointer to the question name (offset 12),
   - TYPE A, CLASS IN,
   - TTL 60 seconds,
   - RDLENGTH 4 and the IPv4 address [target_ip].

   Anything after the first question in the request is dropped. That
   includes an EDNS0 OPT record in the additional section: copying it would
   put it between the question and the answer while the header says
   ARCOUNT=0, which makes the reply unparsable.

   [src_sock] is unused; it is kept for the caller's signature.
   NOTE(review): the answer is always an A record, even when the question's
   QTYPE is something else (e.g. AAAA).

   Returns [(reply, len_reply)] with [len_reply = Bytes.length reply]. *)
let build_reply _src_sock req len_req target_ip =
  (* Offset just past the first question: walk the QNAME labels from the end
     of the 12-byte header to the terminating zero byte, then skip QTYPE and
     QCLASS (2 bytes each). Clamped to [len_req] for malformed queries. *)
  let question_end =
    let rec qname_end pos =
      if pos >= len_req then len_req
      else
        let label_len = Char.code (Bytes.get req pos) in
        if label_len = 0 then pos + 1 else qname_end (pos + 1 + label_len)
    in
    min len_req (qname_end 12 + 4)
  in
  (* Answer RR: name pointer c00c, TYPE=1 (A), CLASS=1 (IN), TTL=60, RDLENGTH=4.
     The 4 address bytes follow. *)
  let answer_prefix = "\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x3c\x00\x04" in
  let len_reply = question_end + String.length answer_prefix + 4 in
  let reply = Bytes.make len_reply '\x00' in
  (* Header: same ID; flags QR=1, RD=1, RA=1, RCODE=0. *)
  Bytes.blit req 0 reply 0 2;
  Bytes.set_uint16_be reply 2 0x8180;
  Bytes.set_uint16_be reply 4 1; (* QDCOUNT: the one question echoed back *)
  Bytes.set_uint16_be reply 6 1; (* ANCOUNT: the one A record appended *)
  (* NSCOUNT and ARCOUNT (bytes 8-11) stay 0. *)
  Bytes.blit req 12 reply 12 (question_end - 12);
  Bytes.blit_string answer_prefix 0 reply question_end
    (String.length answer_prefix);
  Bytes.set_int32_be reply (len_reply - 4) target_ip;
  (reply, len_reply)

(* Pick the answer for [target_domain] for a client at [src_ip].
   Returns the [ip_int32] of the first entry of the domain that lies in the
   client's subnet under that entry's [netmask], or [None]. Logs the outcome
   either way. *)
let resolve_for_client domains target_domain src_ip src_ip_str =
  match Hashtbl.find_opt domains target_domain with
  | None ->
    log ("No such domain: " ^ target_domain);
    None
  | Some ips ->
    let ip_match ip1 ip2 nm = Int32.logand ip1 nm = Int32.logand ip2 nm in
    (match List.find_opt (fun d -> ip_match src_ip d.ip_int32 d.netmask) ips with
     | None ->
       Printf.sprintf "No match of %s for %s" target_domain src_ip_str |> log;
       None
     | Some ip ->
       Printf.sprintf "match %s -> %s for %s" target_domain ip.ip_str src_ip_str
       |> log;
       Some ip.ip_int32)

(* Receive one datagram on [sock] and send back either an answer or a
   name-error reply.
   Raises [Unix.Unix_error] on socket errors, including EINTR when a signal
   such as SIGUSR1 interrupts [recvfrom].
   Raises [Invalid_argument] when the request is malformed. *)
let serve_one_request domains sock =
  (* 512 bytes is the classic DNS-over-UDP message limit (RFC 1035 2.3.4);
     the previous 120-byte read silently truncated longer queries. *)
  let req = Bytes.make 512 '\x00' in
  let len_req, src = Unix.recvfrom sock req 0 (Bytes.length req) [] in
  let src_ip, src_ip_str =
    match src with
    | Unix.ADDR_UNIX _ -> assert false (* the socket is PF_INET *)
    | Unix.ADDR_INET (addr, _port) ->
      let ip = Unix.string_of_inet_addr addr in
      (int_of_ip ip, ip)
  in
  let target_domain = domain_of len_req req in
  let reply, len_reply =
    match resolve_for_client domains target_domain src_ip src_ip_str with
    | Some ip_i32 -> build_reply src req len_req ip_i32
    | None -> build_error_reply req
  in
  ignore (Unix.sendto sock reply 0 len_reply [] src : int)

(* Serve requests forever. Before each request, reload the configuration if
   SIGUSR1 asked for it, keeping the current table when the reload fails.
   The recursive call sits outside any exception handler, so it is a real
   tail call: the stack does not grow with each request. *)
let rec loop_on domains sock =
  let domains =
    if !reload_config then begin
      reload_config := false;
      match load_configuration () with
      | Some new_domains -> new_domains
      | None ->
        (* load_configuration has already logged the error itself *)
        log "Keeping previous configuration";
        domains
    end
    else domains
  in
  (match serve_one_request domains sock with
   | () -> ()
   | exception Unix.Unix_error (_errno, _syscall, _) ->
     () (* Ignore the error (e.g. EINTR from SIGUSR1); the reload check runs next *)
   | exception Invalid_argument msg ->
     log ("Dropping malformed request: " ^ msg));
  loop_on domains sock
(* Entry point: load the configuration (exit 1 if it is unusable), make
   SIGUSR1 request a reload, then serve DNS over UDP on 0.0.0.0:53. *)
let () =
  match load_configuration () with
  | None -> exit 1
  | Some domains ->
    (* The handler only raises a flag; the request loop performs the reload. *)
    Sys.set_signal Sys.sigusr1
      (Sys.Signal_handle (fun _ -> reload_config := true));
    let sock = Unix.socket Unix.PF_INET Unix.SOCK_DGRAM 0 in
    Unix.setsockopt sock Unix.SO_REUSEADDR true;
    (* Binding port 53 typically needs privileges; report it instead of dying
       with an uncaught exception. *)
    (match Unix.bind sock (Unix.ADDR_INET (Unix.inet_addr_any, 53)) with
     | () -> ()
     | exception Unix.Unix_error (err, _, _) ->
       log ("Cannot bind UDP port 53: " ^ Unix.error_message err);
       exit 1);
    loop_on domains sock