package tcpip

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file tcpv6_socket.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
(*
 * Copyright (c) 2014 Anil Madhavapeddy <anil@recoil.org>
 * Copyright (c) 2014 Nicolas Ojeda Bar <n.oje.bar@gmail.com>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *)

(* Log source for this backend; every message of this module goes through
   [Log], which is bound to it. *)
let src =
  Logs.Src.create ~doc:"TCP socket v6 (platform native)" "tcpv6-socket"

module Log = (val Logs.src_log src : Logs.LOG)

open Lwt.Infix

(** Addresses handled by this backend are IPv6 addresses. *)
type ipaddr = Ipaddr.V6.t

(** A TCP flow is the connected Lwt_unix socket descriptor itself. *)
type flow = Lwt_unix.file_descr

(** State of the native-socket TCP/IPv6 layer. *)
type t = {
  interface: Unix.inet_addr;    (* source ip to bind to *)
  mutable active_connections : Lwt_unix.file_descr list;
  (* descriptors opened by [create_connection] or accepted by [listen] *)
  listen_sockets : (int, Lwt_unix.file_descr) Hashtbl.t;
  (* listening socket registered for each local port *)
  mutable switched_off : unit Lwt.t;
  (* accept loops started by [listen] keep running only while this promise
     is still pending *)
}

(** [set_switched_off t p] installs [p] as the promise gating the accept
    loops started by [listen]. Each loop re-reads this field before every
    accept, and stops once the installed promise is no longer pending. *)
let set_switched_off t p =
  t.switched_off <- p

include Tcp_socket

(** [connect addr] builds the initial state. Outgoing and listening sockets
    are bound to the address of the prefix [addr], or to the unspecified
    address when [addr] is [None]. No connections or listeners exist yet.

    NOTE: [switched_off] starts out already resolved, so accept loops started
    by [listen] stop immediately until [set_switched_off] installs a pending
    promise. *)
let connect addr =
  let bind_ip =
    match addr with
    | Some prefix -> Ipaddr.V6.Prefix.address prefix
    | None -> Ipaddr.V6.unspecified
  in
  let interface = Ipaddr_unix.V6.to_inet_addr bind_ip in
  Lwt.return
    { interface;
      active_connections = [];
      listen_sockets = Hashtbl.create 7;
      switched_off = Lwt.return_unit }

(** [disconnect t] closes every tracked connection, then every listening
    socket, using [close] (brought in by [include Tcp_socket]).

    Both collections are emptied before closing. Otherwise [t] would keep
    referring to closed descriptors, and a later [listen]/[unlisten] on the
    same port would [Unix.close] a stale descriptor number that the OS may
    already have reused for an unrelated file. Emptying them also keeps
    repeated calls from closing the same sockets again. *)
let disconnect t =
  let connections = t.active_connections in
  t.active_connections <- [];
  let listeners =
    Hashtbl.fold (fun _ fd acc -> fd :: acc) t.listen_sockets []
  in
  Hashtbl.reset t.listen_sockets;
  Lwt_list.iter_p close connections >>= fun () ->
  Lwt_list.iter_p close listeners

(** [dst fd] returns the remote IPv6 address and port of the connected
    socket [fd].
    @raise Failure if the peer is a Unix-domain address or not an IPv6
    address. Exceptions from [Lwt_unix.getpeername] propagate unchanged. *)
let dst fd =
  match Lwt_unix.getpeername fd with
  | Unix.ADDR_INET (inet, port) ->
    (match Ipaddr_unix.V6.of_inet_addr inet with
     | Some ip -> (ip, port)
     | None -> failwith "got a ipv4 sock instead of a tcpv6 one")
  | Unix.ADDR_UNIX _ ->
    failwith "unexpected: got a unix instead of tcp sock"

(** [create_connection ?keepalive t (dst, dst_port)] opens an IPv6-only TCP
    connection from [t.interface] (ephemeral source port) to
    [dst]:[dst_port]. When [keepalive] is given, TCP keepalive is enabled on
    the socket with its [after], [interval] and [probes] settings.

    On success the descriptor is recorded in [t.active_connections] and
    returned as [Ok fd]. Any exception raised while configuring, binding,
    connecting or enabling keepalive closes the socket and is returned as
    [Error (`Exn exn)]. *)
let create_connection ?keepalive t (dst,dst_port) =
  let fd = Lwt_unix.(socket PF_INET6 SOCK_STREAM 0) in
  Lwt.catch (fun () ->
      (* Configure inside [Lwt.catch] so that a failing setsockopt closes
         [fd] and is reported as [Error], instead of escaping as a
         synchronous exception and leaking the descriptor. *)
      Lwt_unix.(setsockopt fd IPV6_ONLY true);
      Lwt_unix.bind fd (Lwt_unix.ADDR_INET (t.interface, 0)) >>= fun () ->
      Lwt_unix.connect fd
        (Lwt_unix.ADDR_INET ((Ipaddr_unix.V6.to_inet_addr dst), dst_port))
      >>= fun () ->
      ( match keepalive with
        | None -> ()
        | Some { Mirage_protocols.Keepalive.after; interval; probes } ->
          Tcp_socket_options.enable_keepalive ~fd ~after ~interval ~probes );
      t.active_connections <- fd :: t.active_connections;
      Lwt.return (Ok fd))
    (fun exn ->
       close fd >>= fun () ->
       Lwt.return (Error (`Exn exn)))

(** [unlisten t ~port] stops listening on [port]. It removes the socket
    registered for [port] from [t.listen_sockets] and closes the underlying
    Unix descriptor synchronously. It does nothing if no socket is
    registered for [port]. A failure to close is logged at debug level
    rather than raised.

    NOTE(review): the descriptor is closed with [Unix.close], bypassing
    [Lwt_unix.close], so the Lwt_unix wrapper still considers it open. The
    accept loop started by [listen] closes the same [fd] again with [close]
    when it terminates. Confirm this cannot close a descriptor number the OS
    has meanwhile reused. *)
let unlisten t ~port =
  match Hashtbl.find_opt t.listen_sockets port with
  | None -> ()
  | Some fd ->
    Hashtbl.remove t.listen_sockets port;
    (try Unix.close (Lwt_unix.unix_file_descr fd)
     with Unix.Unix_error (err, _, _) ->
       Log.debug (fun m ->
           m "error %s closing listening socket on port %d"
             (Unix.error_message err) port))

(** [listen t ~port ?keepalive callback] replaces any existing listener on
    [port] with a new IPv6-only listening socket bound to
    [t.interface]:[port] (backlog 10). It then starts a background accept
    loop.

    Each accepted descriptor is added to [t.active_connections], gets TCP
    keepalive when [keepalive] is given, and is passed to [callback] in its
    own async thread. If [callback] raises, the exception is logged and the
    accepted descriptor is closed.

    The loop stops when [t.switched_off] is no longer pending; it raises
    [Lwt.Canceled], which is handled by [ignore_canceled]. It also stops when
    accept fails with EBADF. Other accept errors are logged and the loop
    continues. On termination the listening socket is closed.

    @raise Invalid_argument if [port] is outside [0, 65535].
    @raise Unix.Unix_error if the socket cannot be configured, bound or put
    into listening mode. In that case the socket is closed and not
    registered before the exception is re-raised. *)
let listen t ~port ?keepalive callback =
  if port < 0 || port > 65535 then
    raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port));
  unlisten t ~port;
  let fd = Lwt_unix.(socket PF_INET6 SOCK_STREAM 0) in
  (* Finish all setup before registering the socket. A failure here (e.g.
     EADDRINUSE from bind) must neither leak the descriptor nor leave a dead
     entry in [t.listen_sockets]. *)
  (try
     Lwt_unix.setsockopt fd Lwt_unix.SO_REUSEADDR true;
     Lwt_unix.(setsockopt fd IPV6_ONLY true);
     Unix.bind (Lwt_unix.unix_file_descr fd)
       (Lwt_unix.ADDR_INET (t.interface, port));
     Lwt_unix.listen fd 10
   with exn ->
     let bt = Printexc.get_raw_backtrace () in
     (try Unix.close (Lwt_unix.unix_file_descr fd)
      with Unix.Unix_error _ -> ());
     Printexc.raise_with_backtrace exn bt);
  Hashtbl.replace t.listen_sockets port fd;
  (* FIXME: we should not ignore the result *)
  Lwt.async (fun () ->
      (* TODO cancellation *)
      let rec loop () =
        if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ;
        Lwt.catch (fun () ->
            Lwt_unix.accept fd >|= fun (afd, _) ->
            t.active_connections <- afd :: t.active_connections;
            (match keepalive with
             | None -> ()
             | Some { Mirage_protocols.Keepalive.after; interval; probes } ->
               Tcp_socket_options.enable_keepalive ~fd:afd ~after ~interval ~probes);
            Lwt.async
              (fun () ->
                 Lwt.catch
                   (fun () -> callback afd)
                   (fun exn ->
                      Log.warn (fun m -> m "error %s in callback" (Printexc.to_string exn)) ;
                      close afd));
            `Continue)
          (function
            | Unix.Unix_error (Unix.EBADF, _, _) ->
              (* The listening socket was closed (unlisten/disconnect). *)
              Log.warn (fun m -> m "error bad file descriptor in accept") ;
              Lwt.return `Stop
            | exn ->
              Log.warn (fun m -> m "error %s in accept" (Printexc.to_string exn)) ;
              Lwt.return `Continue) >>= function
        | `Continue -> loop ()
        | `Stop -> Lwt.return_unit
      in
      Lwt.catch loop ignore_canceled >>= fun () -> close fd)
OCaml

Innovation. Community. Security.