(* Package [torch] — source file squeezenet.ml. *)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
open Base
open Torch

(** [fire vs in_planes squeeze_planes exp1_planes exp3_planes] builds a
    SqueezeNet "Fire" block.

    The input first goes through a 1x1 "squeeze" convolution
    ([in_planes] -> [squeeze_planes] channels) followed by an in-place ReLU.
    That activation feeds two parallel "expand" convolutions:
    - a 1x1 convolution producing [exp1_planes] channels;
    - a 3x3 convolution with padding 1 producing [exp3_planes] channels.

    Each expand output gets its own in-place ReLU. The two results are then
    concatenated along dimension 1 (presumably the channel axis of an NCHW
    tensor — TODO confirm).

    Every convolution uses stride 1. Weights are registered under the
    ["squeeze"], ["expand1x1"] and ["expand3x3"] sub-stores of [vs]. *)
let fire vs in_planes squeeze_planes exp1_planes exp3_planes =
  (* The three convolutions differ only in sub-store name, kernel size, and
     padding (the 3x3 branch only); stride is always 1. *)
  let conv name ?padding ~ksize ~input_dim out_dim =
    Layer.conv2d_
      (Var_store.sub vs name)
      ~ksize
      ~stride:1
      ?padding
      ~input_dim
      out_dim
  in
  let squeeze = conv "squeeze" ~ksize:1 ~input_dim:in_planes squeeze_planes in
  let exp1 = conv "expand1x1" ~ksize:1 ~input_dim:squeeze_planes exp1_planes in
  let exp3 =
    conv "expand3x3" ~ksize:3 ~padding:1 ~input_dim:squeeze_planes exp3_planes
  in
  Layer.of_fn (fun input ->
    let squeezed = Tensor.relu_ (Layer.forward squeeze input) in
    (* Both expand branches read the same squeezed activation. *)
    let branch layer = Tensor.relu_ (Layer.forward layer squeezed) in
    Tensor.concat ~dim:1 [ branch exp1; branch exp3 ])
;;

(** [squeezenet vs ~version ~num_classes] builds a SqueezeNet classifier whose
    forward function takes [~is_training].

    The feature extractor depends on [version]:
    - [`v1_0]: a 7x7 stride-2 stem convolution from 3 to 96 channels.
    - [`v1_1]: a 3x3 stride-2 stem convolution from 3 to 64 channels.

    In both versions the stem is followed by a ReLU and by interleaved Fire
    blocks and 3x3 / stride-2 max-pools with [ceil_mode]. Feature weights live
    under ["features"/<index>] in [vs]. The indices skip the parameter-free
    layers, which presumably mirrors the reference PyTorch module layout so
    pretrained weights load by name — TODO confirm.

    The head works as follows:
    - dropout with p = 0.5, active only when [is_training];
    - a 1x1 convolution from 512 channels to [num_classes], stored under
      ["classifier"/"1"];
    - an in-place ReLU;
    - adaptive average pooling to 1x1;
    - a reshape to [[batch_size; num_classes]], where [batch_size] is the first
      entry of the input's shape. *)
let squeezenet vs ~version ~num_classes =
  let features =
    let sub_vs i = Var_store.(vs / "features" / Int.to_string i) in
    (* Each call creates a fresh layer, as the original per-position
       [Layer.of_fn] calls did. *)
    let max_pool () =
      Layer.of_fn (Tensor.max_pool2d ~ceil_mode:true ~ksize:(3, 3) ~stride:(2, 2))
    in
    let relu () = Layer.of_fn Tensor.relu_ in
    let stem_conv ~ksize out_dim =
      Layer.conv2d_ (sub_vs 0) ~ksize ~stride:2 ~input_dim:3 out_dim
    in
    let layers =
      match version with
      | `v1_0 ->
        [ stem_conv ~ksize:7 96
        ; relu ()
        ; max_pool ()
        ; fire (sub_vs 3) 96 16 64 64
        ; fire (sub_vs 4) 128 16 64 64
        ; fire (sub_vs 5) 128 32 128 128
        ; max_pool ()
        ; fire (sub_vs 7) 256 32 128 128
        ; fire (sub_vs 8) 256 48 192 192
        ; fire (sub_vs 9) 384 48 192 192
        ; fire (sub_vs 10) 384 64 256 256
        ; max_pool ()
        ; fire (sub_vs 12) 512 64 256 256
        ]
      | `v1_1 ->
        [ stem_conv ~ksize:3 64
        ; relu ()
        ; max_pool ()
        ; fire (sub_vs 3) 64 16 64 64
        ; fire (sub_vs 4) 128 16 64 64
        ; max_pool ()
        ; fire (sub_vs 6) 128 32 128 128
        ; fire (sub_vs 7) 256 32 128 128
        ; max_pool ()
        ; fire (sub_vs 9) 256 48 192 192
        ; fire (sub_vs 10) 384 48 192 192
        ; fire (sub_vs 11) 384 64 256 256
        ; fire (sub_vs 12) 512 64 256 256
        ]
    in
    Layer.sequential layers
  in
  let final_conv =
    let classifier_vs = Var_store.(vs / "classifier" / "1") in
    Layer.conv2d_ classifier_vs ~ksize:1 ~stride:1 ~input_dim:512 num_classes
  in
  Layer.of_fn_ (fun input ~is_training ->
    let batch_size = List.hd_exn (Tensor.shape input) in
    let extracted = Layer.forward features input in
    let dropped = Tensor.dropout extracted ~p:0.5 ~is_training in
    let scores = Tensor.relu_ (Layer.forward final_conv dropped) in
    let pooled = Tensor.adaptive_avg_pool2d scores ~output_size:[ 1; 1 ] in
    Tensor.view pooled ~size:[ batch_size; num_classes ])
;;

(** SqueezeNet 1.0: [squeezenet] specialised to [`v1_0]. *)
let squeezenet1_0 vs ~num_classes = squeezenet ~version:`v1_0 vs ~num_classes

(** SqueezeNet 1.1: [squeezenet] specialised to [`v1_1]. *)
let squeezenet1_1 vs ~num_classes = squeezenet ~version:`v1_1 vs ~num_classes
OCaml

Innovation. Community. Security.