
Alarming RAM Usage Spike during Hyperband Search #1031

@LorenzoMonti

Description

While training a model with a Hyperband search tuner on an HPC system with ~500 GB of RAM, the program's RAM usage grows continuously and memory is not released between trials, similar to what is described in #873. Despite the large system capacity, RAM usage eventually exceeds the available resources and the program, or the entire node, crashes.

This is the relevant code (methods of my model class; `tf` is `tensorflow` and `kt` is `keras_tuner`):

    def build_model(self, hp):
        # Tunable hyperparameters
        num_transformer_blocks = hp.Int(
            'num_transformer_blocks', 
            min_value=2, 
            max_value=8, 
            step=2
        )
        
        head_size = hp.Int(
            'head_size', 
            min_value=32, 
            max_value=128, 
            step=32
        )
        
        num_heads = hp.Int(
            'num_heads', 
            min_value=2, 
            max_value=8, 
            step=2
        )
        
        ff_dim = hp.Int(
            'ff_dim', 
            min_value=64, 
            max_value=256, 
            step=64
        )
        
        learning_rate = hp.Float(
            'learning_rate', 
            min_value=1e-4, 
            max_value=1e-2, 
            sampling='LOG'
        )
        
        dropout_rate = hp.Float(
            'dropout_rate', 
            min_value=0.1, 
            max_value=0.5, 
            step=0.1
        )
        
        sparsity_rate = hp.Float(
            'sparsity_rate', 
            min_value=0.1, 
            max_value=0.5, 
            step=0.1
        )
        
        # Input layer
        inputs = tf.keras.layers.Input(shape=self.input_shape)
        
        # Positional Encoding
        positions = self._positional_encoding(
            self.input_shape[0], 
            self.input_shape[1]
        )
        positions = tf.expand_dims(positions, axis=0)
        
        x = tf.keras.layers.Add()([inputs, positions])
        
        for _ in range(num_transformer_blocks):
            x = self._informer_encoder(
                x, 
                head_size, 
                num_heads, 
                ff_dim, 
                dropout_rate,
                sparsity_rate
            )
        
        # Sequence Length Reduction
        x = tf.keras.layers.GlobalAveragePooling1D()(x)
        
        # MLP Layers
        x = tf.keras.layers.Dense(128, activation="gelu")(x)
        x = tf.keras.layers.Dropout(dropout_rate)(x)
        x = tf.keras.layers.Dense(64, activation="gelu")(x)
        x = tf.keras.layers.Dropout(dropout_rate)(x)
        
        # Output Layer
        outputs = tf.keras.layers.Dense(1)(x)
        
        # Create and compile model
        model = tf.keras.Model(inputs=inputs, outputs=outputs, name='informer_tuned')
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
            loss='mean_squared_error',
            metrics=['mae']
        )
        
        return model

    def _positional_encoding(self, length, d_model):
        ...

    def _informer_encoder(self, inputs, head_size, num_heads, ff_dim, dropout, sparsity_rate):
        ...

    def tune_hyperparameters(self, X_train, X_val, y_train, y_val):
        tuner = kt.Hyperband(
            self.build_model,
            objective='val_mae',
            max_epochs=500,
            factor=3,
            directory=self.output_directory,
            project_name='informer_tuning',
            executions_per_trial=self.executions_per_trial
        )
        
        # Early stopping
        stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=100)
        
        # Run hyperparameter search
        tuner.search(
            X_train, y_train,
            epochs=500,
            validation_data=(X_val, y_val),
            callbacks=[stop_early]
        )
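
To make the growth concrete, here is roughly how I watch memory across trials (this assumes psutil is available on the node, and MemoryLoggingCallback is just an illustrative name):

    import os

    import psutil
    import tensorflow as tf

    class MemoryLoggingCallback(tf.keras.callbacks.Callback):
        """Print the process resident set size (RSS) when each fit starts."""

        def on_train_begin(self, logs=None):
            rss_gb = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 3
            print(f"[memory] RSS at trial start: {rss_gb:.2f} GB")

Passing an instance of this callback alongside stop_early in tuner.search shows the RSS climbing from one trial to the next instead of returning to a baseline.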

At this point, I would like to know whether there are any updates on a fix that does not rely on workarounds such as calling clear_session() before every build_model call.
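
For reference, the workaround I mean is roughly the following (clearing the Keras session and forcing a garbage-collection pass at the top of build_model):

    import gc

    import tensorflow as tf

    def build_model(self, hp):
        # Workaround: drop the previous trial's graph and Python objects
        # before building the next model.
        tf.keras.backend.clear_session()
        gc.collect()
        # ... rest of build_model exactly as above ...
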
Thanks in advance.
