package torch

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file rnn_intf.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
(** The module interface implemented by recurrent neural networks (RNNs). *)
module type S = sig
  (** An RNN, as built by {!create}. *)
  type t

  (** The recurrent state threaded from one {!step} to the next. *)
  type state

  (** [create vs ~input_dim ~hidden_size] creates a new RNN using the
      variable store [vs], accepting inputs of dimension [input_dim] and
      producing hidden states of dimension [hidden_size].
  *)
  val create : Var_store.t -> input_dim:int -> hidden_size:int -> t

  (** [step t state input] applies a single step of the RNN [t] to
      [input], starting from [state], and returns the updated state.
  *)
  val step : t -> state -> Tensor.t -> state

  (** [seq t inputs ~is_training] applies the RNN [t] over every timestep
      of [inputs], starting from a zero state, and returns the pair
      [(outputs, final_state)]: [outputs] holds the hidden state at each
      timestep and [final_state] is the state after the last timestep.

      [inputs] should have shape [batch_size * timesteps * input_dim];
      [outputs] then has shape [batch_size * timesteps * hidden_size].
  *)
  val seq : t -> Tensor.t -> is_training:bool -> Tensor.t * state

  (** [zero_state t ~batch_size] returns an initial (zero) state for the
      RNN [t], sized for a batch of [batch_size] sequences.
  *)
  val zero_state : t -> batch_size:int -> state
end
OCaml

Innovation. Community. Security.