1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
//! Provides configuration of weights and their initialization.

use crate::capnp_util::*;
use crate::co::{ITensorDesc, SharedTensor};
use crate::juice_capnp::weight_config as capnp_config;
use crate::util::native_backend;
use rand::{self, prelude::*};

/// Training-time configuration attached to a single weight blob.
///
/// Every field has a default value; see the `Default` implementation.
#[derive(Debug, Clone)]
pub struct WeightConfig {
    /// Name of the weight blob.
    ///
    /// Only needed when a weight is shared between layers: to share a weight
    /// between two layers, give it a non-empty name. Otherwise it can be left
    /// empty.
    ///
    /// Default: `""`
    pub name: String,
    /// How strictly the dimensions of shared weights are compared: either the
    /// full shape has to match, or only the element count.
    ///
    /// Default: `DimCheckMode::Strict`
    pub share_mode: DimCheckMode,

    /// Multiplier applied to the global learning rate for this parameter.
    ///
    /// Default: `None`, which the `lr_mult()` accessor reads as `1.0f32`.
    pub lr_mult: Option<f32>,

    /// Multiplier applied to the global weight decay for this parameter.
    ///
    /// Default: `None`, which the `decay_mult()` accessor reads as `1.0f32`.
    pub decay_mult: Option<f32>,

    /// Filler used to initialize the values of the weight blob.
    ///
    /// Default: `None`
    pub filler: Option<FillerType>,
}

impl Default for WeightConfig {
    /// Builds the baseline configuration: an unnamed weight with strict
    /// dimension checking and neither multipliers nor a filler set.
    fn default() -> Self {
        Self {
            name: String::new(),
            share_mode: DimCheckMode::Strict,
            lr_mult: None,
            decay_mult: None,
            filler: None,
        }
    }
}

impl WeightConfig {
    /// Verifies that two tensors are compatible for weight sharing under this
    /// configuration's `share_mode`.
    ///
    /// * `DimCheckMode::Permissive` compares only the `size()` of both
    ///   tensor descriptors.
    /// * `DimCheckMode::Strict` compares the complete descriptors.
    ///
    /// `param_name`, `owner_name` and `layer_name` are only used to build the
    /// error message.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when the comparison for the active mode
    /// fails. In that message the descriptor of `tensor_two` is reported as
    /// the owner's shape and the descriptor of `tensor_one` as the sharing
    /// layer's shape.
    pub fn check_dimensions<T>(
        &self,
        tensor_one: &SharedTensor<T>,
        tensor_two: &SharedTensor<T>,
        param_name: String,
        owner_name: String,
        layer_name: String,
    ) -> Result<(), String> {
        let sharer_desc = tensor_one.desc();
        let owner_desc = tensor_two.desc();

        match self.share_mode {
            // Permissive: the element counts differ.
            DimCheckMode::Permissive if sharer_desc.size() != owner_desc.size() => Err(format!(
                "Cannot share weight '{}' owned by layer '{}' with layer '{}';
                                count mismatch.
                                Owner layer weight shape is {:?};
                                Sharing layer weight shape is {:?}",
                param_name, owner_name, layer_name, owner_desc, sharer_desc
            )),
            // Strict: any difference between the descriptors is an error.
            DimCheckMode::Strict if sharer_desc != owner_desc => Err(format!(
                "Cannot share weight '{}' owned by layer '{}' with layer '{}';
                                shape mismatch.
                                Owner layer weight shape is {:?};
                                Sharing layer expects weight shape {:?}",
                param_name, owner_name, layer_name, owner_desc, sharer_desc
            )),
            // The check for the active mode passed.
            DimCheckMode::Permissive | DimCheckMode::Strict => Ok(()),
        }
    }

    /// Learning-rate multiplier for this weight blob, falling back to `1.0`
    /// when none was configured.
    pub fn lr_mult(&self) -> f32 {
        self.lr_mult.unwrap_or(1.0f32)
    }

    /// Weight-decay multiplier for this weight blob, falling back to `1.0`
    /// when none was configured.
    pub fn decay_mult(&self) -> f32 {
        self.decay_mult.unwrap_or(1.0f32)
    }
}

impl<'a> CapnpWrite<'a> for WeightConfig {
    type Builder = capnp_config::Builder<'a>;

    /// Serializes this configuration into the given capnp builder.
    ///
    /// Only the `name` field is written at the moment.
    fn write_capnp(&self, builder: &mut Self::Builder) {
        // TODO: incomplete since WeightConfig isn't really used internally in Juice at the moment.
        let name = self.name.as_str();
        builder.reborrow().set_name(name);
    }
}

impl<'a> CapnpRead<'a> for WeightConfig {
    type Reader = capnp_config::Reader<'a>;

    /// Reconstructs a `WeightConfig` from a capnp reader.
    ///
    /// Only the `name` field is read; every other field takes its default
    /// value.
    ///
    /// # Panics
    ///
    /// Panics if the reader fails to return the name.
    fn read_capnp(reader: Self::Reader) -> Self {
        // TODO: incomplete since WeightConfig isn't really used internally in Juice at the moment.
        Self {
            name: reader.get_name().unwrap().to_owned(),
            ..Self::default()
        }
    }
}

/// Selects how the dimensions of shared weights are compared.
#[derive(Debug, Copy, Clone)]
pub enum DimCheckMode {
    /// The shapes of the weights must match.
    Strict,
    /// Only the number of weights (the count) must match.
    Permissive,
}

/// Selects the kind of filler used to initialize a weight blob.
#[derive(Debug, Copy, Clone)]
pub enum FillerType {
    /// Sets every element of the weight blob to the same constant `value`.
    Constant {
        /// The value written to every element of the blob.
        value: f32,
    },
    /// Initializes the weight blob following the paper:
    ///
    /// `Glorot and Bengio (2010): Understanding the difficulty of training deep feedforward neural networks.`
    ///
    /// This scheme is also known as the Xavier filler.
    Glorot {
        /// Number of input nodes for each output.
        input_size: usize,
        /// Number of output nodes for each input.
        output_size: usize,
    },
}

impl FillerType {
    /// Fills `weight` using the strategy selected by this `FillerType`.
    ///
    /// Dispatches to [`fill_constant`](#method.fill_constant) or
    /// [`fill_glorot`](#method.fill_glorot) with the parameters stored in the
    /// variant. This is usually done directly after the weight blob is
    /// created.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as the selected fill function.
    pub fn fill(&self, weight: &mut SharedTensor<f32>) {
        // Each fill function obtains the native backend itself, so none is
        // created here.
        match *self {
            FillerType::Constant { value } => Self::fill_constant(weight, value),
            FillerType::Glorot {
                input_size,
                output_size,
            } => Self::fill_glorot(weight, input_size, output_size),
        }
    }

    /// Directly use the [Constant Filler](#variant.Constant).
    ///
    /// Overwrites every element of `weight` on the native device with `value`.
    ///
    /// # Panics
    ///
    /// Panics if write-only access to the tensor on the native device cannot
    /// be obtained.
    pub fn fill_constant(weight: &mut SharedTensor<f32>, value: f32) {
        let native = native_backend();
        let native_weight = weight.write_only(native.device()).unwrap();

        for element in native_weight.as_mut_slice::<f32>() {
            *element = value;
        }
    }

    /// Directly use the [Glorot Filler](#variant.Glorot).
    ///
    /// Overwrites every element of `weight` on the native device with a value
    /// drawn uniformly from `[-r, r]`, where
    /// `r = sqrt(6 / (num_inputs + num_outputs))`.
    ///
    /// # Panics
    ///
    /// Panics if write-only access to the tensor on the native device cannot
    /// be obtained.
    ///
    /// NOTE(review): when `num_inputs + num_outputs == 0` the range becomes
    /// infinite; presumably the uniform distribution rejects that -- confirm.
    pub fn fill_glorot(weight: &mut SharedTensor<f32>, num_inputs: usize, num_outputs: usize) {
        let native = native_backend();
        let native_weight = weight.write_only(native.device()).unwrap();

        // Half-width of the symmetric sampling interval.
        let init_range = (6.0f32 / (num_inputs as f32 + num_outputs as f32)).sqrt();
        let distribution = rand::distributions::Uniform::from(-init_range..=init_range);

        let mut rng = thread_rng();
        for element in native_weight.as_mut_slice::<f32>() {
            *element = distribution.sample(&mut rng);
        }
    }
}