(* Source file: dns_client.ml *)
open Dns
(* Effect-free half of the client: building query packets and interpreting
   the bytes received in reply. No IO happens in this module. *)
module Pure = struct

  (** What must be remembered between sending a query and parsing its reply:
      the protocol the query was sent over, the requested record type, and the
      query packet itself (the reply is checked against it). *)
  type 'key query_state =
    { protocol : Dns.proto ;
      key : 'key ;
      query : Packet.t ;
    } constraint 'key = 'a Rr_map.key

  (** [make_query rng protocol hostname record_type] builds a query for the
      [record_type] records of [hostname] with the [Recursion_desired] flag
      set. The 16-bit query id is drawn from [rng] via [Randomconv.int16].
      Returns the bytes to send and the state needed by {!parse_response}.
      Over [`Tcp] the encoded message is preceded by a 2-byte big-endian
      length field (DNS over TCP framing). *)
  let make_query rng protocol hostname
    : 'xy -> Cstruct.t * 'xy query_state =
    fun record_type ->
    let question = Packet.Question.create hostname record_type in
    (* (id, flags) pair consumed by [Packet.create]. *)
    let header =
      Randomconv.int16 rng, Packet.Flags.singleton `Recursion_desired
    in
    let query = Packet.create header question `Query in
    let cs , _ = Packet.encode protocol query in
    begin match protocol with
      | `Udp -> cs
      | `Tcp ->
        let len_field = Cstruct.create 2 in
        Cstruct.BE.set_uint16 len_field 0 (Cstruct.len cs) ;
        Cstruct.concat [len_field ; cs]
    end, { protocol ; query ; key = record_type }

  (** [follow_cname name ~iterations ~answer ~state] looks [name] up in the
      answer section [answer]:
      - a record of the requested type ([state.key]) yields [`Data];
      - otherwise a CNAME is followed, at most [iterations] hops in total
        (exceeding that is an error);
      - a name absent from [answer] yields [`Need_soa name];
      - a present name with neither the requested type nor a CNAME is an
        error. *)
  let rec follow_cname name ~iterations:iterations_left ~answer ~state =
    let open Rresult in
    if iterations_left <= 0
    then Error (`Msg "CNAME recursion too deep")
    else
      match Domain_name.Map.find_opt name answer with
      | None -> Ok (`Need_soa name)
      | Some relevant_map ->
        match Rr_map.find state.key relevant_map with
        | Some response -> Ok (`Data response)
        | None ->
          match Rr_map.(find Cname relevant_map) with
          | None -> Error (`Msg "Invalid DNS response")
          | Some (_ttl, redirected_host) ->
            let iterations = pred iterations_left in
            follow_cname redirected_host ~iterations ~answer ~state

  (** [consume_protocol_prefix buf proto] strips transport framing from [buf].
      Over [`Udp] the buffer is returned unchanged. Over [`Tcp] the 2-byte
      big-endian length prefix is read and exactly that many following bytes
      are returned. [Error ()] means more bytes are needed: either fewer than
      2 bytes are present, or fewer than the announced length. Bytes beyond
      the announced length are dropped after logging a warning. *)
  let consume_protocol_prefix buf =
    function
    | `Udp -> Ok buf
    | `Tcp ->
      match Cstruct.BE.get_uint16 buf 0 with
      | exception Invalid_argument _ -> Error ()
      | pkt_len when pkt_len > Cstruct.len buf - 2 ->
        Logs.debug (fun m -> m "Partial: %d > %d-2"
                       pkt_len (Cstruct.len buf));
        Error ()
      | pkt_len ->
        if 2 + pkt_len < Cstruct.len buf then
          Logs.warn (fun m -> m "Extraneous data in DNS response");
        Ok (Cstruct.sub buf 2 pkt_len)

  (** [find_soa authority] returns an owner name from [authority] that has an
      SOA record, paired with that record, or [None] if there is none. When
      several names have one, the last in fold order wins. *)
  let find_soa authority =
    Domain_name.Map.fold (fun k rr_map acc ->
        match Rr_map.(find Soa rr_map) with
        | Some soa -> Some (Domain_name.raw k, soa)
        | None -> acc)
      authority None

  (** [consume_rest_of_buffer state buf] decodes the DNS message in [buf]
      (framing already removed) as the reply to [state.query]. Results:
      - [`Partial] if the message is incomplete;
      - [`Data x] if the answer section holds the requested records,
        possibly reached through CNAMEs (at most 20 hops);
      - [`No_data soa] if the answer holds no such records and the
        authority section carries an SOA;
      - [`No_domain soa] for an NXDomain reply whose authority section
        carries an SOA;
      - [Error (`Msg _)] for decode failures, replies that do not match the
        query, and any other reply shape. *)
  let consume_rest_of_buffer state buf =
    let open Rresult in
    (* Turns a request/reply mismatch into a message describing both sides. *)
    let to_msg t = function
      | Ok a -> Ok a
      | Error e ->
        R.error_msgf
          "QUERY: @[<v>hdr:%a (id: %d = %d) (q=q: %B)@ query:%a%a \
           opt:%a tsig:%B@,failed: %a@,@]"
          Packet.pp_header t
          (fst t.header) (fst state.query.header)
          (Packet.Question.compare t.question state.query.question = 0)
          Packet.Question.pp t.question
          Packet.pp_data t.data
          (Fmt.option Dns.Edns.pp) t.edns
          (match t.tsig with None -> false | Some _ -> true)
          Packet.pp_mismatch e
    in
    match Packet.decode buf with
    | Error `Partial -> Ok `Partial
    | Error err ->
      Rresult.R.error_msgf "Error parsing response: %a" Packet.pp_err err
    | Ok t ->
      to_msg t (Packet.reply_matches_request ~request:state.query t)
      >>= function
      | `Answer (answer, authority) when not (Domain_name.Map.is_empty answer) ->
        begin
          let q = fst state.query.question in
          follow_cname q ~iterations:20 ~answer ~state >>= function
          | `Data x -> Ok (`Data x)
          | `Need_soa _name ->
            match find_soa authority with
            | Some soa -> Ok (`No_data soa)
            | None -> Error (`Msg "invalid reply, couldn't find SOA")
        end
      | `Answer (_, authority) ->
        begin match find_soa authority with
          | Some soa -> Ok (`No_data soa)
          | None -> Error (`Msg "invalid reply, no SOA in no data")
        end
      | `Rcode_error (NXDomain, Query, Some (_answer, authority)) ->
        begin match find_soa authority with
          | Some soa -> Ok (`No_domain soa)
          | None -> Error (`Msg "invalid reply, no SOA in nodomain")
        end
      | r ->
        Error (`Msg (Fmt.strf "Ok %a, expected answer" Packet.pp_reply r))

  (** [parse_response state buf] interprets [buf], the bytes received so far
      on the connection, as the reply to the query described by [state].
      Returns [`Partial] when more bytes are needed; otherwise the result is
      that of {!consume_rest_of_buffer}. *)
  let parse_response (type requested)
    : requested Rr_map.key query_state -> Cstruct.t ->
      ( [ `Data of requested | `Partial
        | `No_data of [`raw] Domain_name.t * Soa.t
        | `No_domain of [`raw] Domain_name.t * Soa.t ],
        [`Msg of string]) result =
    fun state buf ->
    match consume_protocol_prefix buf state.protocol with
    | Ok buf -> consume_rest_of_buffer state buf
    | Error () -> Ok `Partial
end
(* [stdlib_random n] returns a fresh [n]-byte buffer filled, index by index
   from 0 upwards, with bytes drawn from the stdlib [Random] module.
   NOTE(review): this uses the global [Random] state, which is deterministic
   unless the program seeds it (e.g. with [Random.self_init ()]). If this
   serves as the rng for query ids, those ids are predictable without a seed. *)
let stdlib_random n =
  let out = Cstruct.create n in
  let rec fill idx =
    if idx < n then begin
      Cstruct.set_uint8 out idx (Random.int 256);
      fill (idx + 1)
    end
  in
  fill 0;
  out
(* Default resolver address, as a dotted-quad IPv4 string. It is not
   referenced by the code in this file; presumably transport implementations
   use it as their fallback nameserver -- TODO confirm against callers. *)
let default_resolver = "91.239.100.100"
(* Signature of the transport a client is built on: an IO monad
   ([io], [bind], [lift]) plus connection handling towards a nameserver
   ([connect], [send], [recv], [close]). *)
module type S = sig
(* A connection to a nameserver, obtained from [connect]. *)
type flow
(* The IO monad every effectful operation runs in. *)
type +'a io
(* Transport-specific address of a nameserver. *)
type io_addr
(* A nameserver: the protocol to reach it with, plus its address. *)
type ns_addr = ([`TCP | `UDP]) * io_addr
(* Transport-specific handle passed to [create]. *)
type stack
(* Transport state; the default nameserver and the rng are read back with
   [nameserver] and [rng]. *)
type t
(* [create ?rng ?nameserver stack] builds a transport. Defaults for omitted
   optional arguments are implementation-defined. *)
val create : ?rng:(int -> Cstruct.t) -> ?nameserver:ns_addr -> stack -> t
(* The default nameserver of the transport. *)
val nameserver : t -> ns_addr
(* Random byte source; [rng n] presumably returns [n] random bytes. The
   client uses it to draw query ids. *)
val rng : t -> (int -> Cstruct.t)
(* [connect ?nameserver t] opens a flow to a nameserver. When [?nameserver]
   is omitted, the client assumes the protocol of [nameserver t], so
   implementations are presumably expected to connect to that default. *)
val connect : ?nameserver:ns_addr -> t -> (flow, [> `Msg of string ]) result io
(* Sends one buffer of bytes over the flow. *)
val send : flow -> Cstruct.t -> (unit, [> `Msg of string ]) result io
(* Reads the next chunk of bytes. Over TCP the client calls this repeatedly
   and concatenates the chunks until a full reply is available, so a single
   call need not return a complete DNS message. *)
val recv : flow -> (Cstruct.t, [> `Msg of string ]) result io
(* Closes the flow. *)
val close : flow -> unit io
(* Monadic bind of [io]. *)
val bind : 'a io -> ('a -> 'b io) -> 'b io
(* Injects a pure value into [io]. *)
val lift : 'a -> 'a io
end
(* Apexes of the RFC 6761 special-use names "localhost" and "invalid", each
   paired with a synthetic SOA built by [Soa.create] from the name
   "ns.<zone>". The SOAs are returned in locally produced negative answers. *)
let localhost = Domain_name.of_string_exn "localhost"
let invalid = Domain_name.of_string_exn "invalid"
let localsoa, invalidsoa =
  let soa_of zone = Soa.create (Domain_name.prepend_label_exn zone "ns") in
  soa_of localhost, soa_of invalid
(* [rfc6761_special q_name q_typ] answers queries for RFC 6761 special-use
   names without going to the network:
   - names at or below "localhost" get the loopback address for A and AAAA
     (TTL 300) and a [`No_domain] answer for every other record type;
   - names at or below "invalid" always get [`No_domain].
   [Error ()] means the name is not special and must be resolved normally. *)
let rfc6761_special (type req) q_name (q_typ : req Dns.Rr_map.key) : (Dns_cache.entry, unit) result =
  let within zone = Domain_name.is_subdomain ~domain:zone ~subdomain:q_name in
  if within localhost then
    Ok Dns.Rr_map.(
        match q_typ with
        | A -> `Entry (B (A, (300l, Ipv4_set.singleton Ipaddr.V4.localhost)))
        | Aaaa ->
          `Entry (B (Aaaa, (300l, Ipv6_set.singleton Ipaddr.V6.localhost)))
        | _ -> `No_domain (localhost, localsoa))
  else if within invalid then
    Ok (`No_domain (invalid, invalidsoa))
  else
    Error ()
(* [Make (Transport)] builds a caching DNS stub client on top of [Transport].
   A lookup is answered, in order: from RFC 6761 special names
   ([rfc6761_special]), from the record cache, and finally by querying a
   nameserver over [Transport]. Nameserver answers are stored in the cache. *)
module Make = functor (Transport:S) ->
struct
(* Client state. [clock] supplies the timestamp passed to every cache get
   and set. NOTE(review): the units Dns_cache expects are not visible here. *)
type t = {
cache : Dns_cache.t ;
clock : unit -> int64 ;
transport : Transport.t ;
}
(* [create ?size ?rng ?nameserver ~clock stack] makes a client whose cache
   is [Dns_cache.empty size] (default size 32). [?rng], [?nameserver] and
   [stack] are forwarded to [Transport.create]. *)
let create ?(size=32) ?rng ?nameserver ~clock stack =
{ cache = Dns_cache.empty size ;
clock = clock ;
transport = Transport.create ?rng ?nameserver stack
}
(* The default nameserver of the underlying transport. *)
let nameserver { transport; _ } = Transport.nameserver transport
(* Monadic bind of the transport IO. *)
let (>>=) = Transport.bind
(* Bind for result-valued IO: run [b] on [Ok], propagate [Error] unchanged. *)
let (>>|) a b =
a >>= function
| Ok a' -> b a'
| Error e -> Transport.lift (Error e)
(* Like [>>|], but [f] is a pure function returning a result. *)
let (>>|=) a f = a >>| fun b -> Transport.lift (f b)
(* [lift_ok query_type r] turns a cache-style lookup result into a typed
   answer for [query_type]:
   - [Ok (`Entry b)] yields the stored value when the key of [b] equals
     [query_type]; the [Eq] witness from [K.compare] gives it type [req];
   - [`No_data] and [`No_domain] are returned as errors unchanged;
   - [`Serv_fail] and [Error _] both become [`Msg ""], which
     [get_resource_record] treats as "no usable answer, try the next
     source". *)
let lift_ok (type req) (query_type : req Dns.Rr_map.key) :
(Dns_cache.entry, 'a) result ->
(req, [> `Msg of string
| `No_data of [ `raw ] Domain_name.t * Dns.Soa.t
| `No_domain of [ `raw ] Domain_name.t * Dns.Soa.t ]) result
= function
| Ok `Entry (Dns.Rr_map.B (query_type', value)) ->
begin match Dns.Rr_map.K.compare query_type' query_type with
| Gmap.Order.Eq -> Ok value
| _ ->
Rresult.R.error_msgf "should not happen request_type <> request_type'"
end
| Ok (`No_data _ as nodata) -> Error nodata
| Ok (`No_domain _ as nodom) -> Error nodom
| Ok (`Serv_fail _)
| Error _ -> Error (`Msg "")
(* [get_resource_record t ?nameserver query_type name] resolves the
   [query_type] records of [name]. Negative answers are returned as
   [`No_data] / [`No_domain] errors; other failures as [`Msg]. *)
let get_resource_record (type requested) t ?nameserver (query_type:requested Dns.Rr_map.key) name
: (requested, [> `Msg of string
| `No_data of [ `raw ] Domain_name.t * Dns.Soa.t
| `No_domain of [ `raw ] Domain_name.t * Dns.Soa.t ]) result Transport.io =
let domain_name = Domain_name.raw name in
(* Step 1: RFC 6761 special names are answered locally. *)
match rfc6761_special domain_name query_type |> lift_ok query_type with
| Ok _ as ok -> Transport.lift ok
| Error ((`No_data _ | `No_domain _) as nod) -> Error nod |> Transport.lift
| Error `Msg _ ->
(* Step 2: a cached answer, if there is one. *)
match Dns_cache.get t.cache (t.clock ()) domain_name query_type |> lift_ok query_type with
| Ok _ as ok -> Transport.lift ok
| Error ((`No_data _ | `No_domain _) as nod) -> Error nod |> Transport.lift
| Error `Msg _ ->
(* Step 3: ask a nameserver. The protocol comes from the explicit
   [?nameserver] or, failing that, the transport default. *)
let proto, _ = match nameserver with
| None -> Transport.nameserver t.transport | Some x -> x in
let tx, state =
Pure.make_query (Transport.rng t.transport)
(match proto with `UDP -> `Udp | `TCP -> `Tcp) name query_type
in
(* A connect failure returns early (there is no flow to close). Otherwise
   the parenthesised send/receive chain yields a result [r], and the flow
   is closed before [r] is returned, whether [r] is Ok or Error. *)
Transport.connect ?nameserver t.transport >>| fun socket ->
Logs.debug (fun m -> m "Connected to NS.");
(Transport.send socket tx >>| fun () ->
Logs.debug (fun m -> m "Receiving from NS");
(* Nameserver answers are cached with rank NonAuthoritativeAnswer. *)
let update_cache entry =
let rank = Dns_cache.NonAuthoritativeAnswer in
Dns_cache.set t.cache (t.clock ()) domain_name query_type rank entry
in
(* Read until a complete reply is parsed. [acc] holds the bytes received
   so far. Waiting for more data on a partial reply happens only over TCP;
   a partial UDP reply is an error. *)
let rec recv_loop acc =
Transport.recv socket >>| fun recv_buffer ->
Logs.debug (fun m -> m "Read @[<v>%d bytes@]"
(Cstruct.len recv_buffer)) ;
let buf =
if Cstruct.(equal empty acc)
then recv_buffer
else Cstruct.append acc recv_buffer
in
match Pure.parse_response state buf with
| Ok `Data x ->
update_cache (`Entry (Rr_map.B (query_type, x)));
Ok x |> Transport.lift
| Ok ((`No_data _ | `No_domain _) as nodom) ->
update_cache nodom;
Error nodom |> Transport.lift
| Error `Msg xxx -> Error (`Msg xxx) |> Transport.lift
| Ok `Partial when proto = `TCP -> recv_loop buf
| Ok `Partial -> Error (`Msg "Truncated UDP response") |> Transport.lift
in recv_loop Cstruct.empty) >>= fun r ->
Transport.close socket >>= fun () ->
Transport.lift r
(* Flattens [`No_data] / [`No_domain] errors into [`Msg], so callers only
   see string errors. *)
let lift_cache_error m =
(match m with
| Ok a -> Ok a
| Error `Msg msg -> Error (`Msg msg)
| Error (#Dns_cache.entry as e) ->
Rresult.R.error_msgf "DNS cache error @[%a@]" Dns_cache.pp_entry e)
|> Transport.lift
(* [getaddrinfo t ?nameserver query_type name] is [get_resource_record]
   with negative answers reported as [`Msg] errors. *)
let getaddrinfo (type requested) t ?nameserver (query_type:requested Dns.Rr_map.key) name
: (requested, [> `Msg of string ]) result Transport.io =
get_resource_record t ?nameserver query_type name >>= lift_cache_error
(* [gethostbyname t ?nameserver domain] returns one IPv4 address, picked
   with [Ipv4_set.choose_opt], from the A records of [domain]. The first
   parameter is the client [t], despite being named [stack]. *)
let gethostbyname stack ?nameserver domain =
getaddrinfo stack ?nameserver Dns.Rr_map.A domain >>|= fun (_ttl, resp) ->
match Dns.Rr_map.Ipv4_set.choose_opt resp with
| None -> Error (`Msg "No A record found")
| Some ip -> Ok ip
(* IPv6 counterpart of [gethostbyname]: one address from the AAAA records
   of [domain]. *)
let gethostbyname6 stack ?nameserver domain =
getaddrinfo stack ?nameserver Dns.Rr_map.Aaaa domain >>|= fun (_ttl, res) ->
match Dns.Rr_map.Ipv6_set.choose_opt res with
| None -> Error (`Msg "No AAAA record found")
| Some ip -> Ok ip
end