package torch
PyTorch bindings for OCaml
Install
Dune Dependency
Authors
Maintainers
Sources
0.8.tar.gz
md5=7f9cb5aa0d5e7e9700dde447a1f61c18
sha512=f4f4c23b5ba49cefa6e7f6d51ac1d015e3f6be284a80ceff378a0cd029faaca6026ddea72b8d97e718f7dc37b0879f816b2c789b809939df6881955f155c592f
doc/src/torch.vision/vgg.ml.html
Source file vgg.ml
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
open Base
open Torch

(* Activation layers. Both ignore the training flag; [relu_] goes through the
   in-place primitive [Tensor.relu_] while [relu] uses [Tensor.relu]. *)
let relu = Layer.of_fn_ (fun input ~is_training:_ -> Tensor.relu input)
let relu_ = Layer.of_fn_ (fun input ~is_training:_ -> Tensor.relu_ input)

(* One entry of a feature-extractor configuration. *)
type t =
  | C of int (* conv2d producing the given number of output channels *)
  | M (* maxpool2d *)

(* [layers_cfg cfg] expands a configuration tag into its flat layer list.
   Every configuration is made of five stages with channel widths
   64, 128, 256, 512, 512; a stage is a run of identical [C] entries followed
   by a single [M]. The tags only differ in how many convolutions each stage
   holds. *)
let layers_cfg cfg =
  let stage_widths = [ 64; 128; 256; 512; 512 ] in
  let convs_per_stage =
    match cfg with
    | `A -> [ 1; 1; 2; 2; 2 ]
    | `B -> [ 2; 2; 2; 2; 2 ]
    | `D -> [ 2; 2; 3; 3; 3 ]
    | `E -> [ 2; 2; 4; 4; 4 ]
  in
  List.map2_exn stage_widths convs_per_stage ~f:(fun width count ->
    List.init count ~f:(fun _ -> C width) @ [ M ])
  |> List.concat

(* [make_layers vs cfg ~batch_norm ~in_place_relu] builds the list of feature
   layers described by [layers_cfg cfg].

   - The first convolution takes 3 input channels; each later one takes the
     output width of the convolution before it.
   - Every [C] entry becomes a 3x3 convolution (stride 1, padding 1), then an
     optional 2d batch-norm when [batch_norm] is set, then a ReLU (the
     in-place flavour when [in_place_relu] is set).
   - Every [M] entry becomes a max-pool with a 2x2 kernel.
   - The running index grows by one for each layer emitted, pooling and ReLU
     included. Convolutions and batch-norms register their variables under
     the sub-store named after their own index.
     NOTE(review): presumably this mirrors the numbering of an existing
     pretrained-weight layout -- confirm before changing it. *)
let make_layers vs cfg ~batch_norm ~in_place_relu =
  let activation = if in_place_relu then relu_ else relu in
  let slot index = Var_store.sub vs (Int.to_string index) in
  let max_pool () =
    Layer.with_training (Layer.of_fn (fun xs -> Tensor.max_pool2d xs ~ksize:(2, 2)))
  in
  let conv_block ~input_dim ~idx output_dim =
    (* The convolution is created before the batch-norm so that variables are
       registered in the same order as the layers appear. *)
    let conv =
      Layer.with_training
        (Layer.conv2d_ (slot idx) ~ksize:3 ~stride:1 ~padding:1 ~input_dim output_dim)
    in
    if batch_norm
    then (
      let norm = Layer.batch_norm2d (slot (idx + 1)) output_dim in
      [ conv; norm; activation ])
    else [ conv; activation ]
  in
  let _channels, _next_idx, rev_layers =
    List.fold
      (layers_cfg cfg)
      ~init:(3, 0, [])
      ~f:(fun (input_dim, idx, rev_acc) entry ->
        let output_dim, emitted =
          match entry with
          | M -> input_dim, [ max_pool () ]
          | C output_dim -> output_dim, conv_block ~input_dim ~idx output_dim
        in
        (* Layers are accumulated in reverse and flipped once at the end. *)
        output_dim, idx + List.length emitted, List.rev_append emitted rev_acc)
  in
  List.rev rev_layers

(* [vgg ~num_classes vs cfg ~batch_norm] assembles a full network: the feature
   layers for [cfg] (stored under "features", using in-place ReLU) followed by
   three linear layers stored under "classifier/0", "classifier/3" and
   "classifier/6".

   The forward pass flattens the feature output to [batch_size; -1], where
   [batch_size] is the first dimension of the input, and applies ReLU and
   dropout (p = 0.5, active only when training) after each of the first two
   linear layers. The first linear layer expects 512 * 7 * 7 inputs.
   NOTE(review): that presumably corresponds to 224x224 input images --
   confirm against callers.
   Raises if the input tensor has an empty shape (via [List.hd_exn]). *)
let vgg ~num_classes vs cfg ~batch_norm =
  let classifier_vs i = Var_store.(vs / "classifier" / Int.to_string i) in
  let features =
    Layer.sequential_
      (make_layers (Var_store.sub vs "features") cfg ~batch_norm ~in_place_relu:true)
  in
  let fc1 = Layer.linear (classifier_vs 0) ~input_dim:(512 * 7 * 7) 4096 in
  let fc2 = Layer.linear (classifier_vs 3) ~input_dim:4096 4096 in
  let fc3 = Layer.linear (classifier_vs 6) ~input_dim:4096 num_classes in
  (* Linear layer followed by ReLU and dropout, shared by the two hidden
     layers of the classifier. *)
  let hidden fc input ~is_training =
    Layer.forward fc input |> Tensor.relu |> Tensor.dropout ~p:0.5 ~is_training
  in
  Layer.of_fn_ (fun xs ~is_training ->
    let batch_size = List.hd_exn (Tensor.shape xs) in
    let feature_maps = Layer.forward_ features xs ~is_training in
    let flattened = Tensor.view feature_maps ~size:[ batch_size; -1 ] in
    let h1 = hidden fc1 flattened ~is_training in
    let h2 = hidden fc2 h1 ~is_training in
    Layer.forward fc3 h2)

(* Named variants: the configuration tag selects the depth, the [_bn] suffix
   enables batch normalisation after every convolution. *)
let vgg11 vs ~num_classes = vgg vs `A ~num_classes ~batch_norm:false
let vgg11_bn vs ~num_classes = vgg vs `A ~num_classes ~batch_norm:true
let vgg13 vs ~num_classes = vgg vs `B ~num_classes ~batch_norm:false
let vgg13_bn vs ~num_classes = vgg vs `B ~num_classes ~batch_norm:true
let vgg16 vs ~num_classes = vgg vs `D ~num_classes ~batch_norm:false
let vgg16_bn vs ~num_classes = vgg vs `D ~num_classes ~batch_norm:true
let vgg19 vs ~num_classes = vgg vs `E ~num_classes ~batch_norm:false
let vgg19_bn vs ~num_classes = vgg vs `E ~num_classes ~batch_norm:true

(* [vgg16_layers ?max_layer vs ~batch_norm] builds the feature layers of
   configuration [`D] (non in-place ReLU), keeps at most the first
   [max_layer] of them (all of them by default), and returns a staged
   function. Given an input tensor, that function runs the kept layers in
   order with [~is_training:false] and returns a map from each layer's
   position (starting at 0) to that layer's output. *)
let vgg16_layers ?(max_layer = Int.max_value) vs ~batch_norm =
  let all_layers =
    make_layers (Var_store.sub vs "features") `D ~batch_norm ~in_place_relu:false
  in
  let kept_layers = List.take all_layers max_layer in
  (* [Staged.stage] just indicates that the [?max_layer], [vs] and
     [~batch_norm] parameters should only be applied once: the layers are
     built above, outside the returned closure. *)
  Staged.stage (fun xs ->
    let _last, outputs_by_index =
      List.foldi
        kept_layers
        ~init:(xs, Map.empty (module Int))
        ~f:(fun i (previous, outputs) layer ->
          let current = Layer.forward_ layer previous ~is_training:false in
          current, Map.add_exn outputs ~key:i ~data:current)
    in
    outputs_by_index)
sectionYPositions = computeSectionYPositions($el), 10)"
x-init="setTimeout(() => sectionYPositions = computeSectionYPositions($el), 10)"
>