package hg_lib

  1. Overview
  2. Docs

Source file command_server.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
open Core
open Async
module Def_error = Deferred.Or_error

(** Framing layer for talking to a Mercurial command server over a pipe. As implemented
    here, every message read from the server consists of a one-byte channel identifier,
    a 4-byte big-endian length, and a payload of that many bytes. *)
module Channel_IO : sig
  (** [run_command process args] writes a [runcommand] request carrying [args] to the
      stdin of [process], then reads messages from its stdout until a result message
      arrives. Payloads from the output and error channels are concatenated, in arrival
      order, into [stdout] and [stderr]; a non-zero result code is reported as
      [Error (`Exit_non_zero code)] in [exit_status]. *)
  val run_command : Process.t -> string list -> Process.Output.t Def_error.t

  (** [read reader] reads exactly one message from [reader]. It returns an error on
      premature end of input, on a channel other than [o], [e] or [r], or on a result
      message whose payload is shorter than the 4 bytes needed to decode the code. *)
  val read
    :  Reader.t
    -> [ `Message of [ `Error | `Output ] * string | `Result of int ] Def_error.t
end = struct
  (* Reads the 4-byte big-endian unsigned payload length that follows a channel byte. *)
  let read_length child_stdout =
    let buf = Bytes.create 4 in
    Reader.really_read child_stdout buf
    >>| function
    | `Eof len ->
      (* [len] is the number of bytes received before eof; report them for debugging. *)
      Or_error.error
        "read_length: eof"
        (Bytes.To_string.sub ~pos:0 ~len buf)
        String.sexp_of_t
    | `Ok -> Ok (Binary_packing.unpack_unsigned_32_int_big_endian ~buf ~pos:0)
  ;;

  let read child_stdout =
    Reader.read_char child_stdout
    >>= function
    | `Eof -> Deferred.Or_error.error_string "unexpected eof while reading channel"
    | `Ok channel_char ->
      let channel =
        match channel_char with
        | 'o' -> Ok `Output
        | 'e' -> Ok `Error
        | 'r' -> Ok `Result
        (* these are part of the spec, but unsupported by this implementation *)
        (* | 'I' | 'L' | 'd' *)
        | _ -> Or_error.error "unsupported channel" channel_char Char.sexp_of_t
      in
      (match channel with
       | Error _ as err -> return err
       | Ok channel ->
         read_length child_stdout
         >>=? fun len ->
         let buf = Bytes.create len in
         Reader.really_read child_stdout buf
         >>| (function
           | `Eof len ->
             (* Here [len] shadows the declared length with the count actually read. *)
             Or_error.error
               "eof while reading message"
               (channel, len, Bytes.To_string.sub buf ~pos:0 ~len)
               [%sexp_of: [ `Output | `Error | `Result ] * int * string]
           | `Ok ->
             (match channel with
              | (`Output | `Error) as channel ->
                Ok (`Message (channel, Bytes.to_string buf))
              | `Result ->
                (* Decoding the result code reads 4 bytes from [buf]; a shorter payload
                   would make the unpack raise instead of yielding an [Error]. *)
                if Bytes.length buf < 4
                then
                  Or_error.error
                    "result channel payload shorter than 4 bytes"
                    (Bytes.length buf, Bytes.to_string buf)
                    [%sexp_of: int * string]
                else
                  Ok
                    (`Result
                        (Binary_packing.unpack_signed_32_int_big_endian ~buf ~pos:0)))))
  ;;

  (* Reads messages until a result arrives, then assembles a [Process.Output.t]. *)
  let read_full child_stdout =
    (* Splits the chronologically ordered messages by channel and joins each side. *)
    let flatten outputs =
      let stdouts, stderrs =
        List.partition_map outputs ~f:(fun (channel, text) ->
          match channel with
          | `Output -> First text
          | `Error -> Second text)
      in
      String.concat stdouts, String.concat stderrs
    in
    (* [acc] holds the messages seen so far, most recent first. *)
    let rec loop acc =
      read child_stdout
      >>=? function
      | `Message (channel, text) -> loop ((channel, text) :: acc)
      | `Result exit_code ->
        let stdout, stderr = flatten (List.rev acc) in
        let exit_status =
          if exit_code = 0 then Ok () else Error (`Exit_non_zero exit_code)
        in
        Def_error.return { Process.Output.stdout; stderr; exit_status }
    in
    loop []
  ;;

  (* Writes a [runcommand] request: the literal line, a 4-byte big-endian length, then
     the arguments joined by NUL bytes. Exceptions raised while writing are converted
     into an [Error] that also records [args]. *)
  let send_command child_stdin args =
    let command = String.concat args ~sep:"\000" in
    let buf = Bytes.create 4 in
    Binary_packing.pack_unsigned_32_int_big_endian ~buf ~pos:0 (String.length command);
    try_with
      ~run:`Schedule (* consider [~run:`Now] instead; see: https://wiki/x/ByVWF *)
      ~rest:`Log
      (* consider [`Raise] instead; see: https://wiki/x/Ux4xF *)
      (fun () ->
         Writer.write child_stdin "runcommand\n";
         Writer.write_bytes child_stdin buf;
         Writer.write child_stdin command;
         Writer.flushed child_stdin)
    >>| function
    | Ok _ as ok -> ok
    | Error exn ->
      Or_error.error
        "unable to write command; process is probably dead!"
        (args, exn)
        [%sexp_of: string list * exn]
  ;;

  let run_command process args =
    let child_stdin = Process.stdin process in
    send_command child_stdin args >>=? fun () -> read_full (Process.stdout process)
  ;;
end

(** A handle on a command-server process. The process is held inside a
    [Throttle.Sequencer.t]; [run_command] and [destroy] reach it through
    [Throttle.enqueue]. *)
type t = Process.t Throttle.Sequencer.t

(** [valid_hello ~accepted_encodings hello] checks the hello text sent by a command
    server. The text is read as newline-separated [name: value] lines (lines without a
    colon are skipped). Two conditions are verified, and every failure is reported
    together rather than stopping at the first one:
    - the [capabilities] value, split on spaces, contains [runcommand];
    - the [encoding] value is one of [accepted_encodings]. *)
let valid_hello ~accepted_encodings hello =
  let accepted_encodings =
    List.map accepted_encodings ~f:(fun accepted ->
      match accepted with
      | `Ascii -> "ascii"
      | `Utf8 -> "UTF-8")
  in
  (* Turns one line into a stripped (name, value) pair, splitting at the first colon. *)
  let parse_attr line =
    match String.lsplit2 line ~on:':' with
    | None -> None
    | Some (name, data) -> Some (String.strip name, String.strip data)
  in
  let attrs = hello |> String.split ~on:'\n' |> List.filter_map ~f:parse_attr in
  let lookup key =
    match List.Assoc.find attrs ~equal:String.equal key with
    | Some value -> Ok value
    | None ->
      Or_error.error_s
        [%message "key not in attrs" (key : string) (attrs : (string * string) list)]
  in
  let capabilities_check =
    Or_error.bind (lookup "capabilities") ~f:(fun vals ->
      let capabilities = String.split ~on:' ' vals in
      let has_runcommand =
        List.exists capabilities ~f:(fun value ->
          String.equal "runcommand" (String.strip value))
      in
      match has_runcommand with
      | true -> Ok ()
      | false ->
        Or_error.error_s
          [%message
            "capabilities don't include runcommand" (capabilities : string list)])
  in
  let encoding_check =
    Or_error.bind (lookup "encoding") ~f:(fun encoding ->
      match List.mem ~equal:String.equal accepted_encodings encoding with
      | true -> Ok ()
      | false ->
        Or_error.error_s
          [%message
            "encoding unacceptable; this can be caused by incorrect locale settings, \
             check the output of the `locale` command"
              (accepted_encodings : string list)
              (encoding : string)])
  in
  Or_error.combine_errors_unit [ capabilities_check; encoding_check ]
;;

(* A UTF-8 hello with no trailing newline is accepted when UTF-8 is allowed. *)
let%test _ =
  "capabilities: getencoding runcommand\nencoding: UTF-8"
  |> valid_hello ~accepted_encodings:[ `Utf8 ]
  |> Result.is_ok
;;

(* An ascii hello with a trailing newline is accepted when ascii is allowed. *)
let%test _ =
  "capabilities: getencoding runcommand\nencoding: ascii\n"
  |> valid_hello ~accepted_encodings:[ `Ascii ]
  |> Result.is_ok
;;

(* An ascii hello is rejected when only UTF-8 is allowed. *)
let%test _ =
  "capabilities: getencoding runcommand\nencoding: ascii\n"
  |> valid_hello ~accepted_encodings:[ `Utf8 ]
  |> Result.is_error
;;

(* A UTF-8 hello is rejected when only ascii is allowed. *)
let%test _ =
  "capabilities: getencoding runcommand\nencoding: UTF-8"
  |> valid_hello ~accepted_encodings:[ `Ascii ]
  |> Result.is_error
;;

(* A hello whose capabilities lack runcommand is rejected even if the encoding is fine. *)
let%test _ =
  "capabilities: getencoding\nencoding: ascii\n"
  |> valid_hello ~accepted_encodings:[ `Ascii ]
  |> Result.is_error
;;

(* Checks that when both the capabilities and the encoding are unacceptable, the error
   produced by [valid_hello] lists both problems, capabilities first. The expected text
   below is compared against the printed exception, so it must not be reformatted. *)
let%expect_test "report both errors" =
  let open Expect_test_helpers_core in
  show_raise (fun () ->
    valid_hello
      ~accepted_encodings:[ `Utf8 ]
      "capabilities: getencoding\nencoding: ascii\n"
    |> ok_exn);
  [%expect
    {|
    (raised (
      ("capabilities don't include runcommand" (capabilities (getencoding)))
      ("encoding unacceptable; this can be caused by incorrect locale settings, check the output of the `locale` command"
       (accepted_encodings (UTF-8))
       (encoding ascii))))
  |}];
  return ()
;;

(** Describes how to reach a remote machine with [ssh]; consumed by [create], which
    builds the command line [ssh <options> <user@>host -- <hg_binary> ...]. *)
module Ssh = struct
  type t =
    { host : string (** destination host, placed after the optional [user@] prefix *)
    ; user : string option (** login name; with [None] no [user@] prefix is added *)
    ; options : string list (** arguments placed before the destination *)
    }
end

(** [create ?env ?hg_binary ?config ~accepted_encodings ssh] starts a command server
    with [serve --cmdserver pipe] and validates its hello message.

    - [hg_binary] is the executable to run (default ["hg"]).
    - [config] is a list of [(key, value)] pairs, each passed as [--config key=value].
    - [ssh]: with [None] the server is started locally, in the home directory; with
      [Some _] it is started through [ssh] and no working directory is set.

    On success the process is wrapped in a sequencer. If the hello cannot be read or is
    invalid, the process is sent [Signal.term], its output is collected, and the
    returned error is tagged with the command line and that output. *)
let create ?env ?(hg_binary = "hg") ?config ~accepted_encodings ssh =
  let config_args =
    match config with
    | None -> []
    | Some settings ->
      List.concat_map settings ~f:(fun (key, data) -> [ "--config"; key ^ "=" ^ data ])
  in
  let serve_args = [ "serve"; "--cmdserver"; "pipe" ] @ config_args in
  let prog, args =
    match ssh with
    | None -> hg_binary, serve_args
    | Some { Ssh.host; user; options } ->
      let destination =
        match user with
        | None -> host
        | Some user -> user ^ "@" ^ host
      in
      "ssh", options @ [ destination; "--"; hg_binary ] @ serve_args
  in
  let working_dir =
    match ssh with
    | Some _ -> return (Ok None)
    | None ->
      (* When running a local server, start it in the user's home directory. This makes
         it consistent with running a remote server. *)
      Monitor.try_with_or_error ~here:[%here] Sys.home_directory
      >>|? (fun home -> Some home)
  in
  working_dir
  >>=? fun working_dir ->
  Process.create ?env ?working_dir ~prog ~args ()
  >>=? fun process ->
  (* The first message from the server must be the hello, on the output channel. *)
  let hello_result =
    Channel_IO.read (Process.stdout process)
    >>=? fun reply ->
    match reply with
    | `Message (`Output, text) -> return (valid_hello ~accepted_encodings text)
    | `Message (`Error, error) ->
      Deferred.Or_error.error_s [%message "replied on error channel" (error : string)]
    | `Result result ->
      Deferred.Or_error.error_s
        [%message "replied on result channel, expecting output channel" (result : int)]
  in
  let tagged_hello =
    Deferred.Or_error.tag_arg
      hello_result
      "parsing hello from command server failed"
      (prog, args)
      [%sexp_of: string * string list]
  in
  tagged_hello
  >>= fun outcome ->
  match outcome with
  | Ok () -> Deferred.Or_error.return (Throttle.Sequencer.create process)
  | Error _ as failure ->
    (* Stop the unusable server and attach whatever it printed to the error. *)
    Process.send_signal process Signal.term;
    Process.collect_output_and_wait process
    >>| (fun output ->
    Or_error.tag_arg failure "Process output" output Process.Output.sexp_of_t)
;;

(** [run_command t ~cwd args] enqueues a job on [t] that sends [--cwd cwd] followed by
    [args] to the command server and returns the resulting [Process.Output.t]. *)
let run_command t ~cwd args =
  let full_args = "--cwd" :: cwd :: args in
  Throttle.enqueue t (fun process -> Channel_IO.run_command process full_args)
;;

(** [destroy t] enqueues a job on [t] that calls [Process.collect_output_and_wait] on
    the server process and discards the collected output. *)
let destroy t =
  Throttle.enqueue t (fun process ->
    Process.collect_output_and_wait process >>| fun (_ : Process.Output.t) -> ())
;;
OCaml

Innovation. Community. Security.