package nbd

  1. Overview
  2. Docs

Source file client.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
(*
 * Copyright (C) Citrix Systems Inc.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation; version 2.1 only. with the special
 * exception on linking described in file LICENSE.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *)

open Lwt.Infix
open Protocol
open Channel

(* Export size in bytes, as reported by the server during negotiation. *)
type size = int64

(* Produce a fresh request handle on every call: 0L, 1L, 2L, ...
   The counter is private to this closure and shared by all clients in the
   process. *)
let get_handle =
  let counter = ref 0L in
  fun () ->
    let current = !counter in
    counter := Int64.add current 1L ;
    current

module NbdRpc = struct
  (* Transport/message description consumed by [Mux.Make]: how NBD request
     headers are sent and how simple replies are received and matched back to
     their requests by handle. *)

  type transport = channel

  (* Request handles, as produced by [get_handle]. *)
  type id = int64

  type request_hdr = Request.t

  (* Payload sent after the request header (write data), if any. *)
  type request_body = Cstruct.t option

  type response_hdr = Reply.t

  (* Caller-supplied buffers that a successful read reply is read into. *)
  type response_body = Cstruct.t list

  (* Read one 16-byte reply header and return its handle together with the
     decoded header. A header that fails to unmarshal rejects the promise
     with the unmarshal error. *)
  let recv_hdr sock =
    let hdr = Cstruct.create 16 in
    sock.read hdr >>= fun () ->
    match Reply.unmarshal hdr with
    | Error exn ->
        Lwt.fail exn
    | Ok reply ->
        Lwt.return (Some reply.Reply.handle, reply)

  (* Consume whatever follows a reply header. On a server-reported error,
     return it without reading anything more. On success, only a Read request
     has a payload; it is read, in order, into the response buffers. *)
  let recv_body sock req_hdr res_hdr response_body =
    match (res_hdr.Reply.error, req_hdr.Request.ty) with
    | Error e, _ ->
        Lwt.return_error e
    | Ok (), Command.Read ->
        (* TODO: use a page-aligned memory allocator *)
        Lwt_list.iter_s sock.read response_body >|= fun () -> Ok ()
    | Ok (), _ ->
        Lwt.return_ok ()

  (* Marshal and send the fixed-size request header, then the optional
     payload. *)
  let send_one sock req_hdr req_body =
    let hdr = Cstruct.create Request.sizeof in
    Request.marshal hdr req_hdr ;
    sock.write hdr >>= fun () ->
    match req_body with
    | Some data ->
        sock.write data
    | None ->
        Lwt.return_unit

  (* The multiplexing key of a request is its handle. *)
  let id_of_request {Request.handle; _} = handle

  (* A reply whose handle matches no outstanding request is treated as a
     fatal protocol violation. *)
  let handle_unrequested_packet _transport reply =
    let msg = "Unexpected response from server: " ^ Reply.to_string reply in
    Lwt.fail_with msg
end

(* RPC client obtained by applying [Mux.Make] to the [NbdRpc] description.
   [make] creates one per connection with [Rpc.create]; [write_one] and
   [read] send their requests through [Rpc.rpc]. *)
module Rpc = Mux.Make (NbdRpc)

(* Errors returned by [read]. These are either the generic Mirage_block
   errors or an error code from the server's reply (returned by [Rpc.rpc]),
   wrapped as [`Protocol_error]. *)
type error = [Mirage_block.error | `Protocol_error of Protocol.Error.t]

(* Errors returned by [write]. Same shape as [error], but extending
   Mirage_block's write errors. *)
type write_error =
  [Mirage_block.write_error | `Protocol_error of Protocol.Error.t]

(* Pretty-print a read error. NBD protocol errors are printed with
   [Protocol.Error.to_string]; everything else is delegated to
   Mirage_block. *)
let pp_error ppf err =
  match err with
  | `Protocol_error code ->
      Format.pp_print_string ppf (Protocol.Error.to_string code)
  | #Mirage_block.error as generic ->
      Mirage_block.pp_error ppf generic

(* Pretty-print a write error. NBD protocol errors are printed with
   [Protocol.Error.to_string]; everything else is delegated to
   Mirage_block. *)
let pp_write_error ppf err =
  match err with
  | `Protocol_error code ->
      Format.pp_print_string ppf (Protocol.Error.to_string code)
  | #Mirage_block.write_error as generic ->
      Mirage_block.pp_write_error ppf generic

(* A connected NBD client.
   - [client]: request/reply multiplexer over the connection.
   - [info]: device geometry, computed once in [make] from the export's size
     and flags.
   - [disconnected]: set by [disconnect]; makes later [read]/[write] calls
     return [`Disconnected]. *)
type t = {
    client: Rpc.client
  ; info: Mirage_block.info
  ; mutable disconnected: bool
}

(* Device identifier type; this client uses [unit]. *)
type id = unit

(* Build a client over an already-negotiated [channel].
   [size_bytes] is the export size and [flags] its per-export flags. The
   device is writable unless [flags] contains [Read_only]. *)
let make channel size_bytes flags =
  let read_only = List.mem PerExportFlag.Read_only flags in
  Rpc.create channel >|= fun client ->
  (* NBD addresses bytes, not sectors, so a "sector" here is one byte and
     the sector count equals the byte count. *)
  let info =
    {
      Mirage_block.read_write= not read_only
    ; sector_size= 1
    ; size_sectors= size_bytes
    }
  in
  {client; info; disconnected= false}

(* [list channel] performs the client side of the NBD handshake on [channel]
   and asks the server for its export names using the List option
   (NBD_OPT_LIST).

   Resolves to:
   - [Ok names]: one name per Server reply. Each name is prepended as it
     arrives, so the list is in reverse order of receipt.
   - [Error `Unsupported]: the server only offers the old-style (V1)
     handshake.
   - [Error `Policy]: the server answered with a Policy reply.

   Unmarshalling failures and unexpected reply types reject the promise
   with an exception. Once the reply loop ends (on either Ack or Policy),
   NBD_OPT_ABORT is sent. Any failure while reading the server's ack to
   that abort is logged as a warning, not raised. *)
let list channel =
  let section = Lwt_log_core.Section.make "Client.list" in

  let buf = Cstruct.create Announcement.sizeof in
  channel.read buf >>= fun () ->
  match Announcement.unmarshal buf with
  | Error e ->
      Lwt.fail e
  | Ok kind -> (
      let buf = Cstruct.create (Negotiate.sizeof kind) in
      channel.read buf >>= fun () ->
      match Negotiate.unmarshal buf kind with
      | Error e ->
          Lwt.fail e
      | Ok (Negotiate.V1 _) ->
          Lwt.return_error `Unsupported
      | Ok (Negotiate.V2 x) ->
          let buf = Cstruct.create NegotiateResponse.sizeof in
          (* Echo Fixed_newstyle back only if the server advertised it *)
          let flags =
            if List.mem GlobalFlag.Fixed_newstyle x then
              [ClientFlag.Fixed_newstyle]
            else
              []
          in
          NegotiateResponse.marshal buf flags ;
          channel.write buf >>= fun () ->
          let buf = Cstruct.create OptionRequestHeader.sizeof in
          OptionRequestHeader.(marshal buf {ty= Option.List; length= 0l}) ;
          channel.write buf >>= fun () ->
          let buf = Cstruct.create OptionResponseHeader.sizeof in
          (* Read option replies until Ack (or Policy). Each Server reply is
           * followed by a payload of [length] bytes that holds one export
           * name. [buf] is reused for every reply header. *)
          let rec loop acc =
            channel.read buf >>= fun () ->
            match OptionResponseHeader.unmarshal buf with
            | Error e ->
                Lwt.fail e
            | Ok {OptionResponseHeader.response_type= OptionResponse.Ack; _} ->
                Lwt.return_ok acc
            | Ok {OptionResponseHeader.response_type= OptionResponse.Policy; _}
              ->
                Lwt.return_error `Policy
            | Ok
                {
                  OptionResponseHeader.response_type= OptionResponse.Server
                ; length
                ; _
                } -> (
                let buf' = Cstruct.create (Int32.to_int length) in
                channel.read buf' >>= fun () ->
                match Server.unmarshal buf' with
                | Ok server ->
                    loop (server.Server.name :: acc)
                | Error e ->
                    Lwt.fail e
              )
            | Ok _ ->
                Lwt.fail_with "Server's OptionResponse had an invalid type"
          in
          loop [] >>= fun result ->
          (* Send NBD_OPT_ABORT to terminate the option haggling *)
          let buf = Cstruct.create OptionRequestHeader.sizeof in
          OptionRequestHeader.(marshal buf {ty= Option.Abort; length= 0l}) ;
          channel.write buf >>= fun () ->
          (* The NBD protocol says: "the client SHOULD gracefully handle the
           * server closing the connection after receiving an NBD_OPT_ABORT
           * without it sending a reply" *)
          Lwt.catch
            (fun () ->
              (* Read ack from server *)
              let buf = Cstruct.create OptionResponseHeader.sizeof in
              channel.read buf >>= fun () ->
              match OptionResponseHeader.unmarshal buf with
              | Error e ->
                  Lwt.fail e
              | Ok {OptionResponseHeader.response_type= OptionResponse.Ack; _}
                ->
                  Lwt.return_unit
              | Ok _ ->
                  Lwt.fail_with "Server's OptionResponse had an invalid type"
            )
            (fun exn ->
              (* Best effort: a missing or bad ack is only logged. The list
               * result obtained above is still returned. *)
              Lwt_log_core.warning ~section ~exn
                "Got exception while reading ack from server"
            )
          >|= fun () -> result
    )

(* [negotiate channel export] runs the client side of the NBD handshake on
   [channel] and selects the export named [export].

   - For an old-style (V1) server, the size and flags come directly from the
     server's greeting, and [export] is ignored.
   - For a new-style (V2) server, the client replies with its flags (echoing
     Fixed_newstyle only if the server advertised it), sends the ExportName
     option, and reads back the export's size and flags.

   Resolves to the new client together with the export size and flags.
   Unmarshalling failures reject the promise with the unmarshal error. *)
let negotiate channel export =
  (* Read exactly [n] bytes from the channel into a fresh buffer. *)
  let read_exactly n =
    let buf = Cstruct.create n in
    channel.read buf >|= fun () -> buf
  in
  (* Allocate [size] bytes, let [fill] marshal into them, then send them. *)
  let send size fill =
    let buf = Cstruct.create size in
    fill buf ;
    channel.write buf
  in
  (* Lift an unmarshal result into Lwt, failing the promise on error. *)
  let or_fail = function Ok v -> Lwt.return v | Error exn -> Lwt.fail exn in
  read_exactly Announcement.sizeof >>= fun buf ->
  or_fail (Announcement.unmarshal buf) >>= fun kind ->
  read_exactly (Negotiate.sizeof kind) >>= fun buf ->
  or_fail (Negotiate.unmarshal buf kind) >>= function
  | Negotiate.V1 {Negotiate.size; flags; _} ->
      make channel size flags >|= fun t -> (t, size, flags)
  | Negotiate.V2 global_flags ->
      let client_flags =
        match List.mem GlobalFlag.Fixed_newstyle global_flags with
        | true -> [ClientFlag.Fixed_newstyle]
        | false -> []
      in
      send NegotiateResponse.sizeof (fun buf ->
          NegotiateResponse.marshal buf client_flags)
      >>= fun () ->
      send OptionRequestHeader.sizeof (fun buf ->
          OptionRequestHeader.(
            marshal buf
              {
                ty= Option.ExportName
              ; length= Int32.of_int (String.length export)
              }
          ))
      >>= fun () ->
      send (ExportName.sizeof export) (fun buf -> ExportName.marshal buf export)
      >>= fun () ->
      read_exactly DiskInfo.sizeof >>= fun buf ->
      or_fail (DiskInfo.unmarshal buf) >>= fun {DiskInfo.size; flags; _} ->
      make channel size flags >|= fun t -> (t, size, flags)

(* Device geometry, as computed by [make] at connection time. *)
let get_info {info; _} = Lwt.return info

(* Send a single NBD write request carrying [buffer] at byte offset [from],
   using a fresh handle. Resolves to the unwrapped result of [Rpc.rpc]; no
   reply payload is expected, hence the empty response-buffer list. *)
let write_one t from buffer =
  let handle = get_handle () in
  let len = Cstruct.length buffer |> Int32.of_int in
  let req_hdr = {Request.ty= Command.Write; handle; from; len} in
  Rpc.rpc req_hdr (Some buffer) [] t.client

(* [write t from buffers] writes [buffers] to the export starting at byte
   offset [from]. It sends one NBD write request per buffer, in order, each
   at the offset just past the previous buffer, and stops at the first
   failed request.

   Resolves to:
   - [Error `Disconnected] after [disconnect];
   - [Error `Is_read_only] if the export was negotiated with the Read_only
     flag (nothing is sent to the server);
   - [Error (`Protocol_error e)] when the server reports an error;
   - [Ok ()] otherwise. *)
let write t from buffers =
  if t.disconnected then
    Lwt.return_error `Disconnected
  else if not t.info.Mirage_block.read_write then
    (* [make] sets [read_write] to false when the server flagged the export
       Read_only. Report that the way Mirage_block callers expect, rather
       than sending a write the server would be expected to reject with an
       opaque protocol error. *)
    Lwt.return_error `Is_read_only
  else
    let rec loop from = function
      | [] ->
          Lwt.return_ok ()
      | b :: bs -> (
          write_one t from b >>= function
          | Ok () ->
              (* Next buffer lands immediately after this one. *)
              loop Int64.(add from (of_int (Cstruct.length b))) bs
          | Error e ->
              Lwt.return_error e
        )
    in
    loop from buffers >>= function
    | Error e ->
        Lwt.return_error (`Protocol_error e)
    | Ok () ->
        Lwt.return_ok ()

(* [read t from buffers] fills [buffers], in order, with consecutive bytes
   of the export starting at byte offset [from]. It uses a single NBD read
   request whose length is the combined size of [buffers].
   Resolves to [Error `Disconnected] after [disconnect], and to
   [Error (`Protocol_error e)] when the server reports an error. *)
let read t from buffers =
  match t.disconnected with
  | true ->
      Lwt.return_error `Disconnected
  | false ->
      let handle = get_handle () in
      let total =
        List.fold_left (fun acc buf -> acc + Cstruct.length buf) 0 buffers
      in
      let req_hdr =
        {Request.ty= Command.Read; handle; from; len= Int32.of_int total}
      in
      Rpc.rpc req_hdr None buffers t.client >|= function
      | Ok () -> Ok ()
      | Error e -> Error (`Protocol_error e)

(* Mark [t] as disconnected so that later [read]/[write] calls return
   [`Disconnected]. Only the flag is changed; the underlying channel is not
   touched here. *)
let disconnect t =
  let () = t.disconnected <- true in
  Lwt.return_unit
OCaml

Innovation. Community. Security.