package torch

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file alexnet.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
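(* AlexNet model, in the single-tower variant described in
   A. Krizhevsky, "One weird trick for parallelizing convolutional neural
   networks", https://arxiv.org/abs/1404.5997
*)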
open Base
open Torch

(* Short local aliases used by the layer builders in this file:
   [sub] is [Var_store.sub] and [conv2d] is [Layer.conv2d_]. *)
let sub = Var_store.sub
let conv2d = Layer.conv2d_

(* Convolutional feature extractor.

   The layer hyper-parameters and the sub-store names ("0", "3", "6", "8",
   "10") follow the module layout of torchvision's [AlexNet.features],
   presumably so that weights converted from torchvision load by name.

   For a 224x224 input the spatial size evolves as
   55 (conv1) -> 27 (pool) -> 27 (conv2) -> 13 (pool) -> 13 (conv3..5)
   -> 6 (pool), so the output has 256 channels on a 6x6 grid. That is
   exactly the [256 * 6 * 6] input size the classifier expects.
   NOTE(review): assumes NCHW input with 3 channels, as implied by
   [~input_dim:3] on the first convolution. *)
let features vs =
  let conv1 = conv2d (sub vs "0") ~ksize:11 ~padding:2 ~stride:4 ~input_dim:3 64 in
  (* Stride 1 / padding 2 keeps the spatial size unchanged, as in torchvision.
     A stride of 2 here would halve the map once more and leave only a 2x2
     final output for a 224x224 input, which the later adaptive pooling
     would silently upsample. *)
  let conv2 = conv2d (sub vs "3") ~ksize:5 ~padding:2 ~stride:1 ~input_dim:64 192 in
  let conv3 = conv2d (sub vs "6") ~ksize:3 ~padding:1 ~stride:1 ~input_dim:192 384 in
  let conv4 = conv2d (sub vs "8") ~ksize:3 ~padding:1 ~stride:1 ~input_dim:384 256 in
  let conv5 = conv2d (sub vs "10") ~ksize:3 ~padding:1 ~stride:1 ~input_dim:256 256 in
  (* The same 3x3 / stride-2 max-pooling is used after conv1, conv2 and conv5. *)
  let max_pool xs = Tensor.max_pool2d xs ~ksize:(3, 3) ~stride:(2, 2) in
  Layer.of_fn (fun xs ->
    Layer.forward conv1 xs
    |> Tensor.relu
    |> max_pool
    |> Layer.forward conv2
    |> Tensor.relu
    |> max_pool
    |> Layer.forward conv3
    |> Tensor.relu
    |> Layer.forward conv4
    |> Tensor.relu
    |> Layer.forward conv5
    |> Tensor.relu
    |> max_pool)
;;

(* Fully connected head.

   Pipeline: dropout -> linear (256*6*6 -> 4096) -> relu -> dropout ->
   linear (4096 -> 4096) -> relu -> final projection.

   The sub-store names "1", "4" and "6" follow the module indices of
   torchvision's [AlexNet.classifier]. When [num_classes] is [None], the final
   projection is [Layer.id], so the 4096-wide activations are returned
   unchanged. Both dropout calls use p = 0.5 and receive the caller's
   [~is_training] flag. *)
let classifier ?num_classes vs =
  let fc1 = Layer.linear (sub vs "1") ~input_dim:(256 * 6 * 6) 4096 in
  let fc2 = Layer.linear (sub vs "4") ~input_dim:4096 4096 in
  let head =
    Option.value_map num_classes ~default:Layer.id ~f:(fun n ->
      Layer.linear (sub vs "6") ~input_dim:4096 n)
  in
  Layer.of_fn_ (fun input ~is_training ->
    let drop t = Tensor.dropout t ~p:0.5 ~is_training in
    let hidden1 = input |> drop |> Layer.forward fc1 |> Tensor.relu in
    let hidden2 = hidden1 |> drop |> Layer.forward fc2 |> Tensor.relu in
    Layer.forward head hidden2)
;;

(* Complete AlexNet model.

   Runs [features], adaptively average-pools the result to a 6x6 grid, and
   reshapes it to [batch_size; -1] (256 * 6 * 6 values per example, given the
   256 output channels of [features]). The result is fed to [classifier].
   [num_classes] is forwarded to [classifier]. The batch size is taken from
   the first dimension of the input; [List.hd_exn] raises if the input shape
   is empty. *)
let alexnet ?num_classes vs =
  let feature_layers = features (sub vs "features") in
  let head = classifier ?num_classes (sub vs "classifier") in
  let forward images ~is_training =
    let batch_size = List.hd_exn (Tensor.shape images) in
    let pooled =
      Layer.forward feature_layers images
      |> Tensor.adaptive_avg_pool2d ~output_size:[ 6; 6 ]
    in
    let flattened = Tensor.view pooled ~size:[ batch_size; -1 ] in
    Layer.forward_ head flattened ~is_training
  in
  Layer.of_fn_ forward
;;
OCaml

Innovation. Community. Security.