package torch

  1. Overview
  2. Docs
PyTorch bindings for OCaml

Install

Dune Dependency

Authors

Maintainers

Sources

0.14.tar.gz
md5=7a712ae0e8c7f5452f628377d80a5bb4
sha512=22314b655bc6b5e5c970cbab8d132eae36ee0b8fb0a96b63727899442eb70fe00bd1895d7cc718a85b58bc2b2b4ea6820fa288a19346f095e5de18f7e47c2d02

doc/src/torch/var_store.ml.html

Source file var_store.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
open Base

(* Process-wide unique identifiers used as keys for trainable tensors. *)
module Tensor_id : sig
  include Hashable.Key

  (** [create ()] returns an identifier never returned before in this process:
      ids are handed out in increasing order starting from 1. *)
  val create : unit -> t
end = struct
  include Int

  (* Last identifier handed out; hidden by the signature above. *)
  let last_issued = ref 0

  let create () =
    Int.incr last_issued;
    !last_issued
end

(* A variable store: a tree of named tensors. Each node owns its tensors and
   a table of named child stores (see [sub]).
   TODO: maybe we should also store the full path in the var stores? *)
type t =
  { name : string
        (* Name given to [create]; stores returned by [sub] reuse their
           parent's name, so this is the root's name throughout a tree. *)
  ; trainable_tensors : (Tensor_id.t, Tensor.t) Hashtbl.t
        (* Tensors created with [~trainable:true] in this node, each under a
           fresh id; [freeze]/[unfreeze] toggle their requires_grad flag. *)
  ; all_tensors_by_name : (string, Tensor.t) Hashtbl.t
        (* Every tensor created in this node, keyed by its local name
           (never containing '.'). *)
  ; subs : (string, t) Hashtbl.t
        (* Child stores keyed by their local name (never containing '.'). *)
  ; device : Device.t
        (* Device on which [new_var] allocates tensors. *)
  ; mutable frozen : bool
        (* When [true], tensors created by [new_var] do not require grad. *)
  }

(** [create ?frozen ?device ~name ()] returns an empty root store with no
    tensors and no sub-stores. [frozen] defaults to [false] and [device]
    defaults to [Device.Cpu]. *)
let create ?(frozen = false) ?(device = Device.Cpu) ~name () =
  (* Fresh, empty table keyed by strings; used for both subs and tensors. *)
  let string_table () = Hashtbl.create (module String) in
  { name
  ; device
  ; frozen
  ; subs = string_table ()
  ; all_tensors_by_name = string_table ()
  ; trainable_tensors = Hashtbl.create (module Tensor_id)
  }

(* [first_free_name name table] returns [name] if it is not a key of [table].
   Otherwise it returns the first of [name_1], [name_2], ... that is not a key. *)
let first_free_name name table =
  let taken candidate = Hashtbl.mem table candidate in
  let rec search suffix =
    let candidate = Printf.sprintf "%s_%d" name suffix in
    if taken candidate then search (suffix + 1) else candidate
  in
  if taken name then search 1 else name

(** [sub t sub_name] returns the child store of [t] registered under
    [sub_name], creating and registering an empty one on first use. A new
    child inherits [t]'s name, device and current frozen flag.
    @raise Failure if [sub_name] contains a ['.']. *)
let sub t sub_name =
  match String.contains sub_name '.' with
  | true -> Printf.failwithf "sub names cannot contain ., %s" sub_name ()
  | false ->
    let fresh_child () =
      create ~name:t.name ~device:t.device ~frozen:t.frozen ()
    in
    Hashtbl.find_or_add t.subs sub_name ~default:fresh_child

(* [subi t i] is the child store named by the decimal rendering of [i]. *)
let subi t i = Int.to_string i |> sub t

(* Infix aliases: [t / "name"] for [sub], [t // 3] for [subi]. *)
let ( / ) t sub_name = sub t sub_name
let ( // ) t i = subi t i

(* Recursively sets the frozen flag of [t] and all its sub-stores. It also
   sets the requires_grad flag of every trainable tensor to [not frozen]. *)
let rec set_frozen t ~frozen =
  t.frozen <- frozen;
  let requires_grad = not frozen in
  Hashtbl.iter t.trainable_tensors ~f:(fun tensor ->
      let (_ : Tensor.t) = Tensor.set_requires_grad tensor ~r:requires_grad in
      ());
  Hashtbl.iter t.subs ~f:(fun child -> set_frozen child ~frozen)

(** Stops gradient tracking for all trainable tensors of [t] and its subs. *)
let freeze t = set_frozen t ~frozen:true

(** Re-enables gradient tracking for all trainable tensors of [t] and its subs. *)
let unfreeze t = set_frozen t ~frozen:false

(** Number of trainable tensors in [t] and, recursively, in all its subs. *)
let rec num_trainable_vars t =
  Hashtbl.fold
    t.subs
    ~init:(Hashtbl.length t.trainable_tensors)
    ~f:(fun ~key:_ ~data:child total -> total + num_trainable_vars child)

(** [iter_trainable_vars t ~f] calls [f id tensor] for every trainable tensor.
    Sub-stores are visited before the store's own tensors. *)
let rec iter_trainable_vars t ~f =
  Hashtbl.iter t.subs ~f:(fun child -> iter_trainable_vars child ~f);
  Hashtbl.iteri t.trainable_tensors ~f:(fun ~key ~data -> f key data)

(** [all_vars t] lists every tensor in [t] and its sub-stores, paired with its
    dot-separated path relative to [t] (e.g. ["layer1.weight"]). Each store's
    own tensors come before those of its sub-stores. *)
let all_vars t =
  (* [path] holds the enclosing sub names, innermost first. *)
  let dotted_name ~path key = String.concat ~sep:"." (List.rev (key :: path)) in
  let rec collect store ~path =
    let nested =
      List.concat_map (Hashtbl.to_alist store.subs) ~f:(fun (key, child) ->
          collect child ~path:(key :: path))
    in
    let own =
      List.map (Hashtbl.to_alist store.all_tensors_by_name) ~f:(fun (key, tensor) ->
          dotted_name ~path key, tensor)
    in
    own @ nested
  in
  collect t ~path:[]

(** [copy ~src ~dst] overwrites every tensor of [dst] with the same-named
    tensor of [src], recursing through sub-stores, with gradient tracking
    disabled. Tensors or subs present only in [src] are ignored.
    @raise Failure if [dst] has a tensor or sub that [src] lacks. *)
let copy ~src ~dst =
  (* Dot-separated path of [key] under the reversed prefix [path]. *)
  let dotted key path = List.rev (key :: path) |> String.concat ~sep:"." in
  let rec copy_store ~from ~into path =
    Hashtbl.iteri into.all_tensors_by_name ~f:(fun ~key ~data:target ->
        match Hashtbl.find from.all_tensors_by_name key with
        | None ->
          Printf.failwithf
            "cannot find var %s from var-store %s in %s"
            (dotted key path)
            into.name
            from.name
            ()
        | Some source -> Tensor.copy_ target ~src:source);
    Hashtbl.iteri into.subs ~f:(fun ~key ~data:into_child ->
        match Hashtbl.find from.subs key with
        | None ->
          Printf.failwithf
            "cannot find sub %s from var-store %s in %s"
            (dotted key path)
            into_child.name
            from.name
            ()
        | Some from_child -> copy_store ~from:from_child ~into:into_child (key :: path))
  in
  Tensor.no_grad (fun () -> copy_store ~from:src ~into:dst [])

(** The store's name (shared by all stores of a tree created via [sub]). *)
let name { name; _ } = name

(** Device on which the store allocates new variables. *)
let device { device; _ } = device

(* Initialisation schemes accepted by [new_var]. *)
module Init = struct
  type t =
    | Zeros (* all elements 0 *)
    | Ones (* all elements 1 *)
    | Const of float (* built as [Tensor.ones ~scale:c] *)
    | Normal of
        { mean : float
        ; stdev : float
        } (* [Tensor.randn] scaled by [stdev], then shifted by [mean] *)
    | Uniform of float * float (* [(from, to_)] passed to [Tensor.uniform_] *)
    | Copy of Tensor.t (* a copy of the tensor, moved to the store's device *)
end

(** [new_var ?trainable t ~shape ~init ~name] allocates a tensor of [shape] on
    [t]'s device and initialises it according to [init]. It registers the
    tensor in [t] under [name], or under [name_1], [name_2], ... if [name] is
    already used.

    The tensor requires gradients only when [trainable] (default [true]) holds
    and [t] is not frozen. Every tensor created with [trainable] set is also
    recorded as trainable, so [freeze]/[unfreeze] can toggle it later.

    @raise Failure if [name] contains a ['.']. *)
let new_var ?(trainable = true) t ~shape ~init ~name =
  (* Validate first so an invalid name does not cost a tensor allocation. *)
  if String.contains name '.'
  then Printf.failwithf "tensor names cannot contain ., %s" name ();
  let device = device t in
  let requires_grad = trainable && not t.frozen in
  let tensor =
    match (init : Init.t) with
    | Zeros -> Tensor.zeros shape ~requires_grad ~device
    | Ones -> Tensor.ones shape ~requires_grad ~device
    | Const scale -> Tensor.ones shape ~requires_grad ~device ~scale
    | Normal { mean = 0.; stdev } ->
      Tensor.randn shape ~scale:stdev ~requires_grad ~device
    | Normal { mean; stdev } ->
      (* Shift by [mean] before enabling gradients. Adding to a tensor that
         already requires grad would yield a non-leaf result. Such a tensor
         does not accumulate [.grad], and [freeze]/[unfreeze] cannot toggle
         its requires_grad flag. *)
      Tensor.( + ) (Tensor.randn shape ~scale:stdev ~device) (Tensor.f mean)
      |> Tensor.set_requires_grad ~r:requires_grad
    | Uniform (from, to_) ->
      Tensor.zeros shape ~device
      |> Tensor.uniform_ ~from ~to_
      |> Tensor.set_requires_grad ~r:requires_grad
    | Copy src ->
      Tensor.copy src
      |> Tensor.to_device ~device
      |> Tensor.set_requires_grad ~r:requires_grad
  in
  let name = first_free_name name t.all_tensors_by_name in
  Hashtbl.add_exn t.all_tensors_by_name ~key:name ~data:tensor;
  if trainable
  then Hashtbl.add_exn t.trainable_tensors ~key:(Tensor_id.create ()) ~data:tensor;
  tensor

(** [new_var_copy ?trainable t ~src ~name] registers in [t] a copy of [src]
    with the same shape; see [new_var] for naming and gradient rules. *)
let new_var_copy ?trainable t ~src ~name =
  let shape = Tensor.shape src in
  new_var ?trainable t ~shape ~name ~init:(Init.Copy src)
OCaml

Innovation. Community. Security.