Source file dns_client.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
open Dns
(* Logs source of the DNS client: every message emitted through [Log] is
   reported under the source name dns_client. *)
let src =
  let doc = "DNS client" in
  Logs.Src.create "dns_client" ~doc
module Log = (val (Logs.src_log src) : Logs.LOG)
(* Side-effect-free core of the client: building queries and interpreting
   replies. Apart from debug/warning logging nothing in here performs I/O;
   the randomness used for the query id is supplied by the caller. *)
module Pure = struct

  (** What has to be remembered between sending a query and handling its
      reply: the protocol (it decides the framing), the requested record
      type, and the query packet the reply is matched against. *)
  type 'key query_state =
    { protocol : Dns.proto ;
      key: 'key ;
      query : Packet.t ;
    } constraint 'key = 'a Rr_map.key

  (** [make_query rng protocol ?dnssec edns hostname record_type] builds a
      recursive ([Recursion_desired]) query for the [record_type] records of
      [hostname].
      - [rng] is given to [Randomconv.int16] to choose the query id.
      - [dnssec] (default [false]) additionally sets the [Authentic_data]
        flag.
      - [edns]: [`None] sends no EDNS, [`Manual e] sends [e], and [`Auto]
        sends a [Tcp_keepalive (Some 1200)] option over TCP and nothing over
        UDP.
      Returns the bytes to send (for [`Tcp] prefixed with a 2-byte
      big-endian length) and the state needed to check the reply. *)
  let make_query rng protocol ?(dnssec = false) edns hostname
    : 'xy ->
      string * 'xy query_state =
    fun record_type ->
    let edns = match edns with
      | `None -> None
      | `Manual e -> Some e
      | `Auto -> match protocol with
        | `Udp -> None
        | `Tcp -> Some (Edns.create ~extensions:[Edns.Tcp_keepalive (Some 1200)] ())
    in
    let question = Packet.Question.create hostname record_type in
    (* The header was previously bound as [let =], which does not parse and
       left [header] below unbound. *)
    let header =
      let flags = Packet.Flags.singleton `Recursion_desired in
      let flags =
        if dnssec then Packet.Flags.add `Authentic_data flags else flags
      in
      Randomconv.int16 rng, flags
    in
    let query = Packet.create ?edns header question `Query in
    Log.debug (fun m -> m "sending %a" Dns.Packet.pp query);
    let cs , _ = Packet.encode protocol query in
    begin match protocol with
      | `Udp -> cs
      | `Tcp ->
        (* Over TCP the message is preceded by its length as a 16-bit
           big-endian integer. *)
        let len_field = Bytes.create 2 in
        Bytes.set_uint16_be len_field 0 (String.length cs) ;
        String.concat "" [Bytes.unsafe_to_string len_field ; cs]
    end, { protocol ; query ; key = record_type }

  (** [follow_cname name ~iterations ~answer ~state] looks up the requested
      record type of [name] in [answer], following CNAME records for at most
      [iterations] hops.
      Returns [`Data] when the record is found and [`Need_soa n] when
      [answer] has no entry at all for the current name [n]. It is an error
      when the hop budget is exhausted or when the name has neither the
      requested type nor a CNAME. *)
  let rec follow_cname name ~iterations:iterations_left ~answer ~state =
    if iterations_left <= 0
    then Error (`Msg "CNAME recursion too deep")
    else
      match Domain_name.Map.find_opt name answer with
      | None -> Ok (`Need_soa name)
      | Some relevant_map ->
        match Rr_map.find state.key relevant_map with
        | Some response -> Ok (`Data response)
        | None ->
          match Rr_map.(find Cname relevant_map) with
          | None -> Error (`Msg "Invalid DNS response")
          | Some (_ttl, redirected_host) ->
            let iterations = pred iterations_left in
            follow_cname redirected_host ~iterations ~answer ~state

  (** Strip the transport framing from a received buffer.
      [`Udp] buffers are returned unchanged. [`Tcp] buffers must start with
      a 2-byte big-endian length. [Error ()] means the buffer is shorter than
      that prefix, or holds fewer bytes than it announces. Bytes beyond the
      announced length are dropped, with a warning. *)
  let consume_protocol_prefix buf =
    function
    | `Udp -> Ok buf
    | `Tcp ->
      match String.get_uint16_be buf 0 with
      | exception Invalid_argument _ -> Error ()
      | pkt_len when pkt_len > String.length buf - 2 ->
        (* Message now shows the comparison actually tested ([>], not [>=]). *)
        Log.debug (fun m -> m "Partial: %d > %d-2"
                      pkt_len (String.length buf));
        Error ()
      | pkt_len ->
        if 2 + pkt_len < String.length buf then
          Log.warn (fun m -> m "Extraneous data in DNS response");
        Ok (String.sub buf 2 pkt_len)

  (** Some SOA record from [authority] together with its owner name, or
      [None] if there is none. When several owners carry an SOA, the one
      visited last by [Domain_name.Map.fold] is returned. *)
  let find_soa authority =
    Domain_name.Map.fold (fun k rr_map acc ->
        match Rr_map.(find Soa rr_map) with
        | Some soa -> Some (Domain_name.raw k, soa)
        | None -> acc)
      authority None

  (** Classify a decoded reply to the query in [state]:
      - non-empty answer section: follow CNAMEs (at most 20 hops) from the
        queried name to the requested type, giving [`Data]. If the chain
        leads to a name absent from the answer, the authority SOA gives
        [`No_data].
      - empty answer section: [`No_data] with the authority SOA.
      - NXDomain to a query: [`No_domain] with the authority SOA.
      Anything else, or a reply without the needed SOA, is an error. *)
  let distinguish_answer state =
    let ( let* ) = Result.bind in
    function
    | `Answer (answer, authority) when not (Domain_name.Map.is_empty answer) ->
      begin
        let q = fst state.query.question in
        let* o = follow_cname q ~iterations:20 ~answer ~state in
        match o with
        | `Data x -> Ok (`Data x)
        | `Need_soa _name ->
          match find_soa authority with
          | Some soa -> Ok (`No_data soa)
          | None -> Error (`Msg "invalid reply, couldn't find SOA")
      end
    | `Answer (_, authority) ->
      begin match find_soa authority with
        | Some soa -> Ok (`No_data soa)
        | None -> Error (`Msg "invalid reply, no SOA in no data")
      end
    | `Rcode_error (Rcode.NXDomain, Opcode.Query, Some (_answer, authority)) ->
      begin match find_soa authority with
        | Some soa -> Ok (`No_domain soa)
        | None -> Error (`Msg "invalid reply, no SOA in nodomain")
      end
    | r ->
      Error (`Msg (Fmt.str "Ok %a, expected answer" Packet.pp_reply r))

  (** Decode an unframed reply and check it against [state.query].
      A [`Partial] decoding error is passed through unchanged. Other decode
      errors, and replies that do not match the request, become [`Msg] with
      a diagnostic dump. *)
  let consume_rest_of_buffer state buf =
    let to_msg t =
      Result.map_error (fun e ->
          `Msg
            (Fmt.str
               "QUERY: @[<v>hdr:%a (id: %d = %d) (q=q: %B)@ query:%a%a \
                opt:%a tsig:%B@,failed: %a@,@]"
               Packet.pp_header t
               (fst t.header) (fst state.query.header)
               (Packet.Question.compare t.question state.query.question = 0)
               Packet.Question.pp t.question
               Packet.pp_data t.data
               (Fmt.option Dns.Edns.pp) t.edns
               (match t.tsig with None -> false | Some _ -> true)
               Packet.pp_mismatch e))
    in
    match Packet.decode buf with
    | Error `Partial as e -> e
    | Error err ->
      Error (`Msg (Fmt.str "Error parsing response: %a" Packet.pp_err err))
    | Ok t ->
      Log.debug (fun m -> m "received %a" Dns.Packet.pp t);
      to_msg t (Packet.reply_matches_request ~request:state.query t)

  (** Remove the framing required by [state.protocol] and decode the reply.
      [`Partial] means the buffer does not yet hold a complete reply. *)
  let parse_response (type requested)
    : requested Rr_map.key query_state -> string ->
      (Packet.reply,
       [> `Partial
       | `Msg of string]) result =
    fun state buf ->
    match consume_protocol_prefix buf state.protocol with
    | Ok buf -> consume_rest_of_buffer state buf
    | Error () -> Error `Partial

  (** Parse and classify a reply. An incomplete buffer is reported as
      [Ok `Partial] rather than as an error, so callers can wait for more
      data. *)
  let handle_response (type requested)
    : requested Rr_map.key query_state -> string ->
      ( [ `Data of requested
        | `Partial
        | `No_data of [`raw] Domain_name.t * Soa.t
        | `No_domain of [`raw] Domain_name.t * Soa.t ],
        [`Msg of string]) result =
    fun state buf ->
    match parse_response state buf with
    | Error `Partial -> Ok `Partial
    | Error `Msg _ as e -> e
    | Ok reply -> distinguish_answer state reply
end
(* Host name of the default public resolver. *)
let default_resolver_hostname =
  let raw = Domain_name.of_string_exn "anycast.uncensoreddns.org" in
  Domain_name.host_exn raw
(* Addresses of the default resolvers: one IPv6 address, then one IPv4
   address. *)
let default_resolvers =
  List.map Ipaddr.of_string_exn [ "2001:67c:28a4::" ; "91.239.100.100" ]
(* Interface a transport must provide for [Make]. [Make] uses it to open a
   connection, send one encoded query, read the raw reply and close the
   connection again, all inside the [io] monad. *)
module type S = sig
(* Handle of an open connection: returned by [connect], consumed by
   [send_recv] and [close]. *)
type context
(* The I/O monad; [bind] and [lift] are its operations. *)
type +'a io
(* Nameserver address, as accepted by [create] and reported by
   [nameservers]. *)
type io_addr
(* Network stack passed to [create]. *)
type stack
(* Transport state; [Make] creates one per resolver. *)
type t
(* [create ?nameservers ~timeout stack] builds the transport state.
   [Make.create] always passes [timeout], defaulting it to
   [Duration.of_sec 5]. *)
val create : ?nameservers:(Dns.proto * io_addr list) -> timeout:int64 -> stack -> t
(* Protocol and nameserver addresses configured for this transport. *)
val nameservers : t -> Dns.proto * io_addr list
(* Random source. [Make] passes it to [Pure.make_query], which calls
   [Randomconv.int16] on it to pick the query id. Presumably [rng n] returns
   [n] random bytes; confirm against implementations. *)
val rng : int -> string
(* Current time, used as the timestamp for [Dns_cache.get] and
   [Dns_cache.set]. NOTE(review): the unit is not visible here; confirm it
   matches what Dns_cache expects. *)
val clock : unit -> int64
(* Open a connection to a nameserver. The returned protocol decides the
   framing: for [`Tcp], [Pure.make_query] prefixes the query with a 2-byte
   big-endian length. *)
val connect : t -> (Dns.proto * context, [> `Msg of string ]) result io
(* Send an encoded query and return the raw reply. For [`Tcp] the reply is
   expected to still carry its 2-byte length prefix, because
   [Pure.parse_response] strips it. *)
val send_recv : context -> string -> (string, [> `Msg of string ]) result io
(* Release a connection. [Make] calls it after every exchange, whether
   [send_recv] succeeded or failed. *)
val close : context -> unit io
(* Monadic bind of [io]. *)
val bind : 'a io -> ('a -> 'b io) -> 'b io
(* Inject a plain value into [io]. *)
val lift : 'a -> 'a io
end
(* Special-use names answered locally by [rfc6761_special]. Each has a
   synthetic SOA built by [Soa.create] from the name ns.<name>. *)
let localhost = Domain_name.of_string_exn "localhost"
let localsoa =
  let ns = Domain_name.prepend_label_exn localhost "ns" in
  Soa.create ns
let invalid = Domain_name.of_string_exn "invalid"
let invalidsoa =
  let ns = Domain_name.prepend_label_exn invalid "ns" in
  Soa.create ns
(* Answer queries for the RFC 6761 special-use names without the network.
   Under localhost: A and AAAA queries get the matching loopback address
   with TTL 300, and every other type gets No_domain with the local SOA.
   Under invalid: always No_domain. Any other name yields [Error ()],
   meaning it has to be resolved normally. *)
let rfc6761_special (type req) q_name (q_typ : req Dns.Rr_map.key)
  : (req Dns_cache.entry, unit) result =
  let is_under domain = Domain_name.is_subdomain ~domain ~subdomain:q_name in
  if not (is_under localhost) then
    (if is_under invalid then Ok (`No_domain (invalid, invalidsoa))
     else Error ())
  else
    let ttl = 300l in
    match q_typ with
    | Dns.Rr_map.A ->
      Ok (`Entry (ttl, Ipaddr.V4.Set.singleton Ipaddr.V4.localhost))
    | Dns.Rr_map.Aaaa ->
      Ok (`Entry (ttl, Ipaddr.V6.Set.singleton Ipaddr.V6.localhost))
    | _ -> Ok (`No_domain (localhost, localsoa))
(* Stub resolver over a transport [Transport]. A lookup first checks the
   RFC 6761 special names handled by [rfc6761_special], then the record
   cache, and only then sends a query to a nameserver. The outcome of that
   query is stored in the cache. *)
module Make = functor (Transport:S) ->
struct

  (** Resolver state. [cache] is replaced by a new value after every cache
      lookup and insertion. [edns] is the EDNS policy handed to
      [Pure.make_query] for every outgoing query. *)
  type t = {
    mutable cache : Dns_cache.t ;
    transport : Transport.t ;
    edns : [ `None | `Auto | `Manual of Dns.Edns.t ] ;
  }

  (** The transport this resolver sends its queries through. *)
  let transport { transport ; _ } = transport

  (** [create ?cache_size ?edns ?nameservers ?timeout stack] builds a
      resolver with an empty cache sized by [cache_size] (default 32) and
      EDNS policy [edns] (default [`None]). [nameservers] and [timeout]
      (default [Duration.of_sec 5]) are forwarded to [Transport.create]. *)
  let create ?(cache_size = 32) ?(edns = `None) ?nameservers
      ?(timeout = Duration.of_sec 5) stack =
    { cache = Dns_cache.empty cache_size ;
      transport = Transport.create ?nameservers ~timeout stack ;
      edns ;
    }

  (** Protocol and nameserver addresses of the underlying transport. *)
  let nameservers { transport; _ } = Transport.nameservers transport

  (* Operators over [Transport.io]: [>>=] is bind. [>>|] continues with the
     [Ok] value and passes an [Error] through unchanged. [>>|=] applies a
     pure result-returning function to the [Ok] value. *)
  let (>>=) = Transport.bind

  let (>>|) a b =
    a >>= function
    | Ok a' -> b a'
    | Error e -> Transport.lift (Error e)

  let (>>|=) a f = a >>| fun b -> Transport.lift (f b)

  (** Convert a cache-entry lookup into the resolver result type.
      [`Entry v] becomes [Ok v]. [`No_data] and [`No_domain] become the
      matching errors. A cached [`Serv_fail] and a failed lookup both become
      [`Msg] with an empty message, which the callers below treat as
      "ask the next source". *)
  let lift_ok (type req) :
    (req Dns_cache.entry, 'a) result ->
    (req, [> `Msg of string
          | `No_data of [ `raw ] Domain_name.t * Dns.Soa.t
          | `No_domain of [ `raw ] Domain_name.t * Dns.Soa.t ]) result
    = function
      | Ok `Entry value -> Ok value
      | Ok (`No_data _ as nodata) -> Error nodata
      | Ok (`No_domain _ as nodom) -> Error nodom
      | Ok (`Serv_fail _)
      | Error _ -> Error (`Msg "")

  (* Connect to a nameserver, run [f proto socket], then close the socket
     whether [f] produced [Ok] or [Error]. This sequence used to be
     duplicated in [get_raw_reply] and [get_resource_record]. *)
  let with_connection t f =
    Transport.connect t.transport >>| fun (proto, socket) ->
    Log.debug (fun m -> m "Connected to NS.");
    f proto socket >>= fun r ->
    Transport.close socket >>= fun () ->
    Transport.lift r

  (** [get_raw_reply t query_type name] sends a query for the [query_type]
      records of [name] with [~dnssec:true] and returns the decoded reply
      without interpreting it. The cache is neither consulted nor updated. *)
  let get_raw_reply t query_type name =
    with_connection t (fun proto socket ->
        let tx, state =
          Pure.make_query Transport.rng proto ~dnssec:true t.edns name query_type
        in
        Transport.send_recv socket tx >>| fun recv_buffer ->
        Log.debug (fun m -> m "Read @[<v>%d bytes@]"
                      (String.length recv_buffer)) ;
        Log.debug (fun m -> m "received: %a" (Ohex.pp_hexdump ()) recv_buffer);
        Transport.lift (Pure.parse_response state recv_buffer))

  (** [get_resource_record t query_type name] resolves the [query_type]
      records of [name]. RFC 6761 special names are answered locally, and a
      usable cache entry is returned directly. Otherwise a query is sent,
      and its data, [`No_data] or [`No_domain] outcome is cached with rank
      [NonAuthoritativeAnswer] before being returned. *)
  let get_resource_record (type requested) t (query_type:requested Dns.Rr_map.key) name
    : (requested, [> `Msg of string
                  | `No_data of [ `raw ] Domain_name.t * Dns.Soa.t
                  | `No_domain of [ `raw ] Domain_name.t * Dns.Soa.t ]) result Transport.io =
    let domain_name = Domain_name.raw name in
    match rfc6761_special domain_name query_type |> lift_ok with
    | Ok _ as ok -> Transport.lift ok
    | Error ((`No_data _ | `No_domain _) as nod) -> Error nod |> Transport.lift
    | Error `Msg _ ->
      let cache', r =
        Dns_cache.get t.cache (Transport.clock ()) domain_name query_type
      in
      t.cache <- cache';
      match lift_ok (Result.map fst r) with
      | Ok _ as ok -> Transport.lift ok
      | Error ((`No_data _ | `No_domain _) as nod) -> Error nod |> Transport.lift
      | Error `Msg _ ->
        with_connection t (fun proto socket ->
            let tx, state =
              Pure.make_query Transport.rng proto t.edns name query_type
            in
            Transport.send_recv socket tx >>| fun recv_buffer ->
            Log.debug (fun m -> m "Read @[<v>%d bytes@]"
                          (String.length recv_buffer)) ;
            let update_cache entry =
              let rank = Dns_cache.NonAuthoritativeAnswer in
              let cache =
                Dns_cache.set t.cache (Transport.clock ()) domain_name query_type rank entry
              in
              t.cache <- cache
            in
            Transport.lift
              (match Pure.handle_response state recv_buffer with
               | Ok `Data x ->
                 update_cache (`Entry x);
                 Ok x
               | Ok ((`No_data _ | `No_domain _) as nodom) ->
                 update_cache nodom;
                 Error nodom
               | Error `Msg xxx -> Error (`Msg xxx)
               | Ok `Partial ->
                 (* The buffer did not hold a complete reply. Name the
                    protocol actually used; this was previously always
                    reported as UDP, even over TCP. *)
                 let proto_name =
                   match proto with `Udp -> "UDP" | `Tcp -> "TCP"
                 in
                 Error (`Msg ("Truncated " ^ proto_name ^ " response"))))

  (** Collapse the resolver error type to [`Msg]. [`No_data], [`No_domain]
      and any other cache entry are rendered with [Dns_cache.pp_entry]. *)
  let lift_cache_error query_type m =
    (match m with
     | Ok a -> Ok a
     | Error `Msg msg -> Error (`Msg msg)
     | Error (#Dns_cache.entry as e) ->
       Error (`Msg (Fmt.str "DNS cache error @[%a@]" (Dns_cache.pp_entry query_type) e)))
    |> Transport.lift

  (** Like [get_resource_record], but every failure is reported as [`Msg]. *)
  let getaddrinfo (type requested) t (query_type:requested Dns.Rr_map.key) name
    : (requested, [> `Msg of string ]) result Transport.io =
    get_resource_record t query_type name >>= lift_cache_error query_type

  (** One IPv4 address of [domain], as picked by [Ipaddr.V4.Set.choose_opt],
      or an error when there is none. *)
  let gethostbyname stack domain =
    getaddrinfo stack Dns.Rr_map.A domain >>|= fun (_ttl, resp) ->
    match Ipaddr.V4.Set.choose_opt resp with
    | None -> Error (`Msg "No A record found")
    | Some ip -> Ok ip

  (** One IPv6 address of [domain], as picked by [Ipaddr.V6.Set.choose_opt],
      or an error when there is none. *)
  let gethostbyname6 stack domain =
    getaddrinfo stack Dns.Rr_map.Aaaa domain >>|= fun (_ttl, res) ->
    match Ipaddr.V6.Set.choose_opt res with
    | None -> Error (`Msg "No AAAA record found")
    | Some ip -> Ok ip
end