package nbd

  1. Overview
  2. Docs

Source file mux.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
(*
 * Copyright (C) Citrix Systems Inc.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation; version 2.1 only. with the special
 * exception on linking described in file LICENSE.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *)

(** Lwt connection multiplexer. Multiplexes between parallel requests from
    multiple clients over a single output channel to a server that may send
    responses out of order. Each request and response carries an [id] that is
    used to match responses to requests. *)

open Lwt.Infix

module type RPC = sig
  (** The transport mechanism used to send and receive messages *)
  type transport

  (** Each [request_hdr] and [response_hdr] carries an [id] that is used to
      match responses to requests. *)
  type id

  (** Header of an outgoing request. Its [id] is obtained with
      [id_of_request]. *)
  type request_hdr

  (** Payload sent together with a [request_hdr] by [send_one]. *)
  type request_body

  (** Header of an incoming message, as returned by [recv_hdr]. *)
  type response_hdr

  (** Destination into which [recv_body] writes the body of a response. *)
  type response_body

  val recv_hdr : transport -> (id option * response_hdr) Lwt.t
  (** [recv_hdr transport] receives the next header from [transport] and
      returns it together with the [id] of the request it answers, if any.
      The multiplexer hands headers that come with no [id] to
      [handle_unrequested_packet]. *)

  val recv_body :
       transport
    -> request_hdr
    -> response_hdr
    -> response_body
    -> (unit, Protocol.Error.t) result Lwt.t
  (** [recv_body transport request_hdr response_hdr response_body] returns [Ok ()]
      and receives and writes the body of the response into [response_body] if
      the request has been successful, otherwise returns an [Error]. The
      [response_hdr] parameter is the output of a preceding [recv_hdr] call,
      and [request_hdr] is the header of the request whose [id] matched that
      response. *)

  val send_one : transport -> request_hdr -> request_body -> unit Lwt.t
  (** Send a single request. Invocations of this function will not be interleaved
      because they are protected by a mutex *)

  val id_of_request : request_hdr -> id
  (** [id_of_request request_hdr] returns the [id] carried by [request_hdr].
      The multiplexer uses it as the key under which the pending request is
      registered, so that the response with the same [id] can be routed back
      to it. *)

  val handle_unrequested_packet : transport -> response_hdr -> unit Lwt.t
  (** [handle_unrequested_packet transport response_hdr] is called by the
      multiplexer when [recv_hdr] returned a header without an [id], i.e. one
      that does not answer any request. *)
end

module Make (R : RPC) : sig
  type client

  val rpc :
       R.request_hdr
    -> R.request_body
    -> R.response_body
    -> client
    -> (unit, Protocol.Error.t) Result.t Lwt.t
  (** [rpc req_hdr req_body response_body client] sends a request to the server, and
      saves the response into [response_body]. Will block until a response to
      this request is received from the server.

      The returned promise is rejected if the client has shut down, if
      sending the request fails (with the exception raised by the send), or
      if the dispatcher fails while the request is pending (with the
      exception that stopped the dispatcher). *)

  val create : R.transport -> client Lwt.t
  (** [create transport] creates a new client that manages parallel requests
      over the given transport channel. All communication over this channel
      must go through the returned client. *)
end = struct
  (* Raised inside the dispatcher when a response carries an id for which no
     request is pending. *)
  exception Unexpected_id of R.id

  (* Rejection reason for [rpc] calls made after the dispatcher has failed. *)
  exception Shutdown

  type client = {
      transport: R.transport
    ; outgoing_mutex: Lwt_mutex.t
          (* Serialises calls to [R.send_one] on [transport]. *)
    ; id_to_wakeup:
        ( R.id
        , R.request_hdr
          * (unit, Protocol.Error.t) result Lwt.u
          * R.response_body
        )
        Hashtbl.t
          (* Pending requests: id -> (request header, resolver of the promise
             returned by [rpc], buffer receiving the response body). *)
    ; mutable dispatcher_thread: unit Lwt.t
    ; mutable dispatcher_shutting_down: bool
          (* Set once the dispatcher has failed; [rpc] then refuses new
             requests. *)
  }

  (* Receive loop: reads one header at a time, routes the response body to
     the matching pending request and wakes up its caller. Any exception
     stops the loop for good: the client is marked as shutting down and every
     pending request is rejected with that exception. *)
  let rec dispatcher t =
    let th =
      Lwt.catch
        (fun () ->
          R.recv_hdr t.transport >>= fun (id, pkt) ->
          match id with
          | None ->
              R.handle_unrequested_packet t.transport pkt
          | Some id -> (
            match Hashtbl.find_opt t.id_to_wakeup id with
            | None ->
                Lwt.fail (Unexpected_id id)
            | Some (request_hdr, waker, response_body) ->
                (* The entry stays registered while the body is received so
                   that, should [recv_body] raise, the handler below still
                   rejects this request. *)
                R.recv_body t.transport request_hdr pkt response_body
                >>= fun response ->
                (* Unregister before waking up: [Lwt.wakeup] runs the
                   callbacks of the caller immediately, and those may issue
                   a new request reusing this id, whose fresh binding must
                   not be removed. *)
                Hashtbl.remove t.id_to_wakeup id ;
                Lwt.wakeup waker response ;
                Lwt.return ()
          )
        )
        (fun e ->
          t.dispatcher_shutting_down <- true ;
          Hashtbl.iter
            (fun _ (_, u, _) -> Lwt.wakeup_later_exn u e)
            t.id_to_wakeup ;
          (* Every pending request has been notified and no new one can be
             registered: drop the entries so they are not retained. *)
          Hashtbl.reset t.id_to_wakeup ;
          Lwt.fail e
        )
    in
    th >>= fun () -> dispatcher t

  let rpc req_hdr req_body response_body t =
    if t.dispatcher_shutting_down then
      Lwt.fail Shutdown
    else
      let sleeper, waker = Lwt.wait () in
      let id = R.id_of_request req_hdr in
      (* Register before sending so that the response cannot arrive before
         the dispatcher knows about the request. *)
      Hashtbl.add t.id_to_wakeup id (req_hdr, waker, response_body) ;
      Lwt.catch
        (fun () ->
          Lwt_mutex.with_lock t.outgoing_mutex (fun () ->
              R.send_one t.transport req_hdr req_body
          )
        )
        (fun e ->
          (* The request never reached the server, so no response will come:
             unregister it instead of leaking the entry. Only our own
             binding is removed, in case the id has been reused meanwhile. *)
          ( match Hashtbl.find_opt t.id_to_wakeup id with
          | Some (_, w, _) when w == waker ->
              Hashtbl.remove t.id_to_wakeup id
          | Some _ | None ->
              ()
          ) ;
          Lwt.fail e
        )
      >>= fun () -> sleeper

  let create transport =
    let t =
      {
        transport
      ; outgoing_mutex= Lwt_mutex.create ()
      ; id_to_wakeup= Hashtbl.create 10
      ; dispatcher_thread= Lwt.return ()
      ; dispatcher_shutting_down= false
      }
    in
    (* The dispatcher starts running immediately; the promise is kept in the
       record rather than awaited. *)
    t.dispatcher_thread <- dispatcher t ;
    Lwt.return t
end
OCaml

Innovation. Community. Security.