package mqtt

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file Mqtt_client.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
(** [fmt format args...] renders a [Format] format string (so [%a] printers
    are allowed) into a plain [string]; shorthand for [Format.asprintf]. *)
let fmt format = Format.asprintf format

(** A live broker connection: the channel packets are read from, paired with
    the channel packets are written to. *)
type connection = Lwt_io.input Lwt_io.channel * Lwt_io.output Lwt_io.channel

(** Decodes the variable-length "Remaining Length" field of an MQTT fixed
    header from [inch].

    Each byte carries 7 value bits, least-significant group first; the high
    bit is set on every byte except the last one. The protocol allows at most
    four such bytes, so a longer sequence is rejected with
    [Failure "malformed remaining length"] instead of being read forever. *)
let decode_length inch =
  let rec loop value mult =
    let%lwt ch = Lwt_io.read_char inch in
    let ch = Char.code ch in
    let digit = ch land 127 in
    let value = value + (digit * mult) in
    if ch land 128 = 0 then Lwt.return value
    else if mult >= 128 * 128 * 128 then
      (* The 4th byte still has its continuation bit set: a 5th length byte
         would follow, which the protocol forbids. *)
      Lwt.fail (Failure "malformed remaining length")
    else loop value (mult * 128)
  in
  loop 0 1

(** Reads one MQTT control packet from [inch].

    The packet is read as the fixed-header byte, then the remaining length,
    then exactly that many body bytes. The body is decoded with
    [Mqtt_packet.Decoder.decode_packet].

    Returns the options decoded from the fixed header together with the
    decoded packet. If the channel ends before the whole body has been read,
    the promise fails with [Failure "could not read bytes"]. *)
let read_packet inch =
  let%lwt header_byte = Lwt_io.read_char inch in
  let msgid, opts =
    Mqtt_packet.Decoder.decode_fixed_header (Char.code header_byte)
  in
  let%lwt count = decode_length inch in
  let data = Bytes.create count in
  (* [try%lwt] rather than a plain [try]: [read_into_exactly] reports
     end-of-file by rejecting its promise, which a synchronous [try] handler
     never observes. *)
  let%lwt () =
    try%lwt Lwt_io.read_into_exactly inch data 0 count
    with End_of_file -> Lwt.fail (Failure "could not read bytes")
  in
  let pkt =
    Read_buffer.make (Bytes.to_string data)
    |> Mqtt_packet.Decoder.decode_packet opts msgid
  in
  Lwt.return (opts, pkt)

(** Lwt-aware logger for this module; messages are reported under the
    ["mqtt.client"] log source. *)
module Log =
  (val Logs_lwt.src_log (Logs.Src.create "mqtt.client") : Logs_lwt.LOG)

(** State of a connected MQTT client. *)
type t = {
  cxn : connection;
      (** Channels of the underlying broker connection (input, output). *)
  id : string;
      (** Client identifier sent in CONNECT; also used to tag log messages. *)
  inflight : (int, unit Lwt_condition.t * Mqtt_packet.t) Hashtbl.t;
      (** Requests awaiting an acknowledgement, keyed by packet id. Each entry
          holds the condition signalled once the ack is received and the ack
          packet expected from the broker. *)
  mutable reader : unit Lwt.t;
      (** Promise of the background packet reader; starts as
          [Lwt.return_unit] and is replaced once the reader is started. *)
  on_message : topic:string -> string -> unit Lwt.t;
      (** Called with the topic and payload of incoming PUBLISH packets. *)
  on_disconnect : t -> unit Lwt.t;
      (** Called by [disconnect] after the DISCONNECT packet is written. *)
  on_error : t -> exn -> unit Lwt.t;
      (** Receives the exceptions caught by [wrap_catch]. *)
  should_stop_reader : unit Lwt_condition.t;
      (** Signalled to make the packet reader loop stop. *)
}

(** [wrap_catch client f] runs [f ()] and hands any exception it raises,
    synchronously or as a rejected promise, to [client.on_error]. *)
let wrap_catch client f =
  Lwt.catch f (fun exn -> client.on_error client exn)

(** Fallback [on_error] handler: logs [exn] at error level, tagged with the
    client id, then re-raises it as a rejected promise. *)
let default_on_error client exn =
  Lwt.bind
    (Log.err (fun log ->
         log "[%s]: Unhandled exception: %a" client.id Fmt.exn exn))
    (fun () -> Lwt.fail exn)

(** Default [on_message] callback: ignores every incoming message. *)
let default_on_message ~topic:_ _payload = Lwt.return ()

(** Default [on_disconnect] callback: does nothing. *)
let default_on_disconnect _client = Lwt.return ()

(** Packet reader loop of [client].

    Reads packets from the connection until [client.should_stop_reader] is
    signalled:
    - QoS 0 PUBLISH packets are passed to [on_message];
    - QoS 1 PUBLISH packets are passed to [on_message], then acknowledged
      with a PUBACK;
    - SUBACK, UNSUBACK, PUBACK, PUBREC, PUBREL and PUBCOMP packets complete
      the matching entry of [client.inflight] and wake up its waiter;
    - PINGRESP packets are ignored.

    Any other packet (including QoS 2 PUBLISH), or an ack that matches no
    in-flight request, makes the loop fail with [Failure]. *)
let read_packets client =
  let in_chan, out_chan = client.cxn in

  (* Whether [pkt] is an acceptable ack for a request awaiting [expected].
     A broker may grant a lower QoS than the one requested in SUBSCRIBE
     (MQTT 3.1.1, section 3.9.3), so a SUBACK is accepted when it carries
     the same packet id and one non-failure return code per requested topic.
     Every other ack must match exactly. *)
  let is_expected_ack (pkt : Mqtt_packet.t) (expected : Mqtt_packet.t) =
    match (pkt, expected) with
    | Suback (id, granted), Suback (expected_id, requested) ->
      id = expected_id
      && List.length granted = List.length requested
      && List.for_all (function Ok _ -> true | Error _ -> false) granted
    | _ -> pkt = expected
  in

  (* Completes the in-flight request [id] and wakes its waiter, provided
     [pkt] is the ack that request is waiting for. *)
  let ack_inflight id pkt =
    match Hashtbl.find_opt client.inflight id with
    | None -> Lwt.fail (Failure (fmt "ack for id=%d not found" id))
    | Some (cond, expected_ack_pkt) ->
      if is_expected_ack pkt expected_ack_pkt then (
        Hashtbl.remove client.inflight id;
        Lwt_condition.signal cond ();
        Lwt.return_unit)
      else Lwt.fail (Failure "unexpected packet in ack")
  in

  let rec loop () =
    let%lwt (_dup, qos, _retain), packet = read_packet in_chan in
    let%lwt () =
      match packet with
      (* Publish with QoS 0: push *)
      | Publish (None, topic, payload) when qos = Atmost_once ->
        client.on_message ~topic payload
      (* Publish with QoS 0 and packet identifier: error *)
      | Publish (Some _id, _topic, _payload) when qos = Atmost_once ->
        Lwt.fail
          (Failure
             "protocol violation: publish packet with qos 0 must not have id")
      (* Publish with QoS 1 *)
      | Publish (Some id, topic, payload) when qos = Atleast_once ->
        (* - Push the message to the consumer queue.
           - Send back the PUBACK packet. *)
        let%lwt () = client.on_message ~topic payload in
        let puback = Mqtt_packet.Encoder.puback id in
        Lwt_io.write out_chan puback
      | Publish (None, _topic, _payload) when qos = Atleast_once ->
        Lwt.fail
          (Failure
             "protocol violation: publish packet with qos > 0 must have id")
      | Publish _ ->
        Lwt.fail (Failure "not supported publish packet (probably qos 2)")
      | Suback (id, _)
      | Unsuback id
      | Puback id
      | Pubrec id
      | Pubrel id
      | Pubcomp id ->
        ack_inflight id packet
      | Pingresp -> Lwt.return_unit
      | _ -> Lwt.fail (Failure "unknown packet from server")
    in
    loop ()
  in

  let%lwt () =
    Log.debug (fun log -> log "[%s] Starting reader loop..." client.id)
  in
  (* Whichever finishes first wins: the stop signal cancels the loop. *)
  Lwt.pick
    [
      (let%lwt () = Lwt_condition.wait client.should_stop_reader in
       Log.info (fun log -> log "[%s] Stopping reader loop..." client.id));
      loop ();
    ]

(** Gracefully disconnects [client].

    The steps are, in order:
    - send (and flush) an MQTT DISCONNECT packet;
    - signal the packet reader loop to stop; when the reader stops, the
      background task started by [connect] shuts the connection down;
    - run the [on_disconnect] callback.

    The reader is signalled even if sending DISCONNECT fails; that failure is
    then propagated and [on_disconnect] is not run. *)
let disconnect client =
  let%lwt () =
    Log.info (fun log -> log "[%s] Disconnecting client..." client.id)
  in
  let _, oc = client.cxn in
  (* DISCONNECT must be on the wire before the reader is told to stop:
     stopping it lets the background task close the channels, and a write
     issued afterwards could hit an already-closed channel. *)
  let%lwt () =
    Lwt.finalize
      (fun () ->
        let%lwt () = Lwt_io.write oc (Mqtt_packet.Encoder.disconnect ()) in
        Lwt_io.flush oc)
      (fun () ->
        Lwt_condition.signal client.should_stop_reader ();
        Lwt.return_unit)
  in
  let%lwt () = client.on_disconnect client in
  Log.info (fun log -> log "[%s] Client disconnected." client.id)

(** Closes [client]'s connection: flushes pending output, then closes the
    input channel and the output channel, logging before and after. *)
let shutdown client =
  let input, output = client.cxn in
  Lwt.Infix.(
    Log.debug (fun log -> log "[%s] Shutting down the connection..." client.id)
    >>= fun () ->
    Lwt_io.flush output >>= fun () ->
    Lwt_io.close input >>= fun () ->
    Lwt_io.close output >>= fun () ->
    Log.debug (fun log -> log "[%s] Client connection shut down." client.id))

(** Opens a TLS connection to [host]:[port], authenticating the server against
    the CA certificates in [ca_file].

    Any failure (loading the CA file, connecting, or the TLS handshake) is
    logged with its cause and then re-raised unchanged. *)
let open_tls_connection ~client_id ~ca_file host port =
  try%lwt
    let%lwt authenticator = X509_lwt.authenticator (`Ca_file ca_file) in
    Tls_lwt.connect authenticator (host, port)
  with exn ->
    let%lwt () =
      Log.err (fun log ->
          log "[%s] could not open TLS connection to `%s:%d`: %a" client_id
            host port Fmt.exn exn)
    in
    Lwt.fail exn

(** Keep-alive loop: sends a PINGREQ on [client]'s connection every
    [0.75 *. keep_alive] seconds, forever.

    A [keep_alive] of zero (or less) means keep-alive is disabled (MQTT 3.1.1,
    section 3.1.2.10). No ping is sent in that case. The returned promise
    never resolves, but it can be cancelled (e.g. by [Lwt.pick]), so callers
    can treat both cases the same way. *)
let run_pinger ~keep_alive client =
  if keep_alive <= 0 then (
    let%lwt () =
      Log.debug (fun log -> log "Keep-alive disabled, ping timer not started.")
    in
    (* [Lwt.task] (not [Lwt.wait]) so that cancellation is honoured. *)
    fst (Lwt.task ()))
  else (
    let%lwt () = Log.debug (fun log -> log "Starting ping timer...") in
    let _, output = client.cxn in
    (* 25% leeway so the ping arrives before the broker's deadline. *)
    let keep_alive = 0.75 *. float_of_int keep_alive in
    let rec loop () =
      let%lwt () = Lwt_unix.sleep keep_alive in
      let pingreq_packet = Mqtt_packet.Encoder.pingreq () in
      let%lwt () = Lwt_io.write output pingreq_packet in
      loop ()
    in
    loop ())

(** Failure (delivered as a rejected promise) raised when no broker connection
    could be established, or when the broker did not accept the CONNECT
    handshake. *)
exception Connection_error

(** Opens a plain TCP connection to [host]:[port].

    The host is resolved to stream-socket addresses, and each address is
    tried in order until one accepts the connection. A [Unix.Unix_error]
    from one address moves on to the next one; the error from the last
    address is propagated. Fails with [Connection_error] (after logging) if
    the host resolves to no address at all. *)
let open_tcp_connection ~client_id host port =
  let%lwt addresses =
    Lwt_unix.getaddrinfo host (string_of_int port)
      [ Unix.AI_SOCKTYPE Unix.SOCK_STREAM ]
  in
  let rec connect_to = function
    | [] ->
      let%lwt () =
        Log.err (fun log ->
            log "[%s] could not get address info for %S" client_id host)
      in
      Lwt.fail Connection_error
    | [ address ] -> Lwt_io.open_connection address.Unix.ai_addr
    | address :: rest -> (
      try%lwt Lwt_io.open_connection address.Unix.ai_addr
      with Unix.Unix_error _ -> connect_to rest)
  in
  connect_to addresses

(** Connects to the first reachable host of [hosts] on [port].

    A TLS connection is used when [tls_ca] (a CA file) is given; otherwise a
    plain TCP connection is used. Hosts are tried in order; each failure is
    logged at debug level with its cause. Fails with [Connection_error],
    after logging the full host list, if every host fails. Cancellation is
    propagated rather than treated as a host failure. *)
let create_connection ?tls_ca ~port ~client_id hosts =
  let rec try_hosts = function
    | [] ->
      (* Log the original list: the remaining list is always empty here. *)
      let%lwt () =
        Log.err (fun log ->
            log "[%s] Could not connect to any of the hosts (on port %d): %a"
              client_id port
              Fmt.Dump.(list string)
              hosts)
      in
      Lwt.fail Connection_error
    | host :: remaining -> (
      try%lwt
        let%lwt () =
          Log.debug (fun log ->
              log "[%s] Connecting to `%s:%d`..." client_id host port)
        in
        let%lwt connection =
          match tls_ca with
          | Some ca_file -> open_tls_connection ~client_id ~ca_file host port
          | None -> open_tcp_connection ~client_id host port
        in
        let%lwt () =
          Log.info (fun log ->
              log "[%s] Connection opened on `%s:%d`." client_id host port)
        in
        Lwt.return connection
      with
      | Lwt.Canceled as exn -> Lwt.fail exn
      | exn ->
        let%lwt () =
          Log.debug (fun log ->
              log "[%s] Could not connect to `%s:%d` (%a), trying next host..."
                client_id host port Fmt.exn exn)
        in
        try_hosts remaining)
  in
  try_hosts hosts

(** Connects to the first reachable broker among [hosts] (on [port], over TLS
    when [tls_ca] is given) and performs the MQTT CONNECT/CONNACK handshake.

    On success, returns a client whose packet reader and keep-alive pinger
    run in a background task; when either of them ends, that task shuts the
    connection down.

    Fails with [Connection_error] if no host is reachable, if the broker
    refuses the connection, or if its first reply is not a CONNACK. An error
    while sending CONNECT or reading the reply is propagated as is. On every
    handshake failure the network connection is closed before the promise
    fails. *)
let connect ?(id = "ocaml-mqtt") ?tls_ca ?credentials ?will
    ?(clean_session = true) ?(keep_alive = 30)
    ?(on_message = default_on_message) ?(on_disconnect = default_on_disconnect)
    ?(on_error = default_on_error) ?(port = 1883) hosts =
  (* An empty client id always forces a clean session. *)
  let flags =
    if clean_session || id = "" then [ Mqtt_packet.Clean_session ] else []
  in
  let cxn_data =
    { Mqtt_packet.clientid = id; credentials; will; flags; keep_alive }
  in

  let%lwt ((ic, oc) as connection) =
    create_connection ?tls_ca ~port ~client_id:id hosts
  in

  (* Until the handshake succeeds no client exists, so [shutdown] never runs
     for this connection: release it here on every failure path. Close errors
     are ignored so that the original failure is the one reported. *)
  let close_connection () =
    let close_quietly close = Lwt.catch close (fun _ -> Lwt.return_unit) in
    let%lwt () = close_quietly (fun () -> Lwt_io.close ic) in
    close_quietly (fun () -> Lwt_io.close oc)
  in
  let fail_closing exn =
    let%lwt () = close_connection () in
    Lwt.fail exn
  in

  let connect_packet =
    Mqtt_packet.Encoder.connect ?credentials:cxn_data.credentials
      ?will:cxn_data.will ~flags:cxn_data.flags ~keep_alive:cxn_data.keep_alive
      cxn_data.clientid
  in
  let inflight = Hashtbl.create 16 in
  let%lwt response =
    try%lwt
      let%lwt () = Lwt_io.write oc connect_packet in
      read_packet ic
    with exn -> fail_closing exn
  in

  match response with
  | _, Connack { connection_status = Accepted; session_present } ->
    let%lwt () =
      Log.debug (fun log ->
          log "[%s] Connection acknowledged (session_present=%b)" id
            session_present)
    in

    let client =
      {
        cxn = connection;
        id;
        inflight;
        reader = Lwt.return_unit;
        should_stop_reader = Lwt_condition.create ();
        on_message;
        on_disconnect;
        on_error;
      }
    in

    (* Background task: run the reader and the pinger until one of them
       ends, then shut the connection down. *)
    Lwt.async (fun () ->
        client.reader <- wrap_catch client (fun () -> read_packets client);
        let%lwt () =
          Log.debug (fun log -> log "[%s] Packet reader started." client.id)
        in
        let%lwt () =
          Lwt.pick [ client.reader; run_pinger ~keep_alive client ]
        in
        let%lwt () =
          Log.debug (fun log ->
              log "[%s] Packet reader stopped, shutting down..." client.id)
        in
        shutdown client);

    Lwt.return client
  | _, Connack pkt ->
    let conn_status =
      Mqtt_packet.connection_status_to_string pkt.connection_status
    in
    let%lwt () =
      Log.err (fun log -> log "[%s] Connection failed: %s" id conn_status)
    in
    fail_closing Connection_error
  | _ ->
    let%lwt () =
      Log.err (fun log ->
          log "[%s] Invalid response from broker on connection" id)
    in
    fail_closing Connection_error

(** Publishes [payload] on [topic] through [client].

    The returned promise resolves as follows:
    - QoS 0 ([Atmost_once]): once the packet is written (packet id 0);
    - QoS 1 ([Atleast_once]): once the matching PUBACK has been received;
    - QoS 2 ([Exactly_once]): after the PUBLISH/PUBREC exchange, a PUBREL is
      sent, and the promise resolves once PUBCOMP has been received.

    Acks are delivered by the reader loop through [client.inflight]. *)
let publish ?(dup = false) ?(qos = Mqtt_core.Atleast_once) ?(retain = false)
    ~topic payload client =
  let _, oc = client.cxn in
  (* Registers [expected] as the ack awaited for packet [id], writes the
     packet produced by [encode], then waits for the reader to signal it. *)
  let exchange id cond expected encode =
    Hashtbl.add client.inflight id (cond, expected);
    let%lwt () = Lwt_io.write oc (encode ()) in
    Lwt_condition.wait cond
  in
  let encode_publish id () =
    Mqtt_packet.Encoder.publish ~dup ~qos ~retain ~id ~topic payload
  in
  match qos with
  | Atmost_once -> Lwt_io.write oc (encode_publish 0 ())
  | Atleast_once ->
    let id = Mqtt_packet.gen_id () in
    let cond = Lwt_condition.create () in
    exchange id cond (Mqtt_packet.puback id) (encode_publish id)
  | Exactly_once ->
    let id = Mqtt_packet.gen_id () in
    let cond = Lwt_condition.create () in
    let%lwt () = exchange id cond (Mqtt_packet.pubrec id) (encode_publish id) in
    exchange id cond (Mqtt_packet.pubcomp id) (fun () ->
        Mqtt_packet.Encoder.pubrel id)

(** Subscribes [client] to [topics], a list of (topic filter, QoS) pairs.

    The returned promise resolves once the broker's SUBACK for this request
    has been received by the reader loop. Failures while sending or waiting
    are handed to [client.on_error].

    @raise Invalid_argument immediately (not through the promise) when
    [topics] is empty. *)
let subscribe topics client =
  if topics = [] then invalid_arg "empty topics";
  let pkt_id = Mqtt_packet.gen_id () in
  let request = Mqtt_packet.Encoder.subscribe ~id:pkt_id topics in
  let requested = List.map (fun (_topic, qos) -> Ok qos) topics in
  let suback_received = Lwt_condition.create () in
  Hashtbl.add client.inflight pkt_id
    (suback_received, Suback (pkt_id, requested));
  let _, oc = client.cxn in
  wrap_catch client (fun () ->
      let%lwt () = Lwt_io.write oc request in
      let%lwt () = Lwt_condition.wait suback_received in
      Log.info (fun log ->
          log "[%s] Subscribed to %a." client.id
            Fmt.Dump.(list string)
            (List.map fst topics)))

(* Re-export every definition of [Mqtt_core] (such as the QoS constructors
   used by [publish]) as part of this module. *)
include Mqtt_core
OCaml

Innovation. Community. Security.