
Potential BUG: HyperNEAT class #32

@mg10011 (Contributor)

Description

I think the HyperNEAT class has a bug. Its setup() method reads as follows:

```python
class HyperNEAT(BaseAlgorithm):
    def __init__(
        self,
        # ... (constructor arguments elided) ...
    ):
        # ...
        self.pop_size = neat.pop_size
        # ...

    def setup(self, state=State()):
        state = self.neat.setup(state)
        state = self.substrate.setup(state)
        return self.hyper_genome.setup(state)
```

The signature def setup(self, state=State()): should instead be def setup(self, state: State):, without the default. That way, state is not automatically a new State() object.
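For reference, a Python default argument is evaluated once, when the def statement runs, and is used only when the caller omits that argument. The minimal sketch below uses a hypothetical stand-in State class (not TensorNEAT's) to show that state=State() never replaces a state the caller passes in:

```python
class State:
    """Hypothetical stand-in for a state container (not TensorNEAT's State)."""


def setup(state=State()):
    # The default State() above was created exactly once, when `def` ran.
    return state


a = setup()
b = setup()
explicit = State()
c = setup(explicit)

print(a is b)         # True: both calls reuse the single default object
print(c is explicit)  # True: an explicitly passed state is used as-is
```

Sharing one default object is only safe if nothing mutates it in place, so this pattern relies on setup() returning a new state rather than modifying its argument.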

Thoughts?

Activity

WLS2002 (Collaborator) commented on May 11, 2025

Thanks for the suggestion!

The HyperNEAT class (like many other classes in TensorNEAT) inherits from StatefulBaseClass, which defines the setup(self, state=State()) method. The use of state=State() in setup() is intentional; the goal is to support both use cases:

  1. If the user does not provide a state, the default empty State() from the StatefulBaseClass signature is used. The algorithm then stores its internal setup data in this state and returns a new state object with the updates.
  2. If the user does provide a state, the algorithm stores its setup information in that state, again returning a new (modified) state object.
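Both cases can be sketched with a toy, hypothetical State class that mimics the register/update behavior discussed in this thread (an assumption for illustration; this is not TensorNEAT's implementation). Because setup() never mutates its argument and always returns a new state, the single default State() stays empty and can be reused safely across calls:

```python
class State:
    """Hypothetical immutable state container (illustration only)."""

    def __init__(self, **kwargs):
        self._data = dict(kwargs)

    def register(self, **kwargs):
        # Add new keys only; refuse to overwrite existing ones.
        for k in kwargs:
            if k in self._data:
                raise ValueError(f"{k} is already registered")
        return State(**{**self._data, **kwargs})

    def update(self, **kwargs):
        # Replace existing keys only; refuse to add new ones.
        for k in kwargs:
            if k not in self._data:
                raise KeyError(f"{k} is not registered")
        return State(**{**self._data, **kwargs})

    def get(self, key):
        return self._data[key]

    def __contains__(self, key):
        return key in self._data


class ToyAlgorithm:
    def setup(self, state=State()):
        # Returns a NEW state; the default object itself is never modified.
        return state.register(pop_size=8)


algo = ToyAlgorithm()

# Use case 1: no state given, so the shared empty default is used.
s1 = algo.setup()
s1_again = algo.setup()  # no "already registered" error: the default is still empty

# Use case 2: the caller supplies a state that already holds a key.
s2 = algo.setup(State(randkey=42))
s3 = s2.update(randkey=7)  # update() replaces an existing key

print("pop_size" in s1, "randkey" in s1)    # True False
print("pop_size" in s2, s2.get("randkey"))  # True 42
```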

So I believe this design is reasonable, and I'm happy to continue the discussion if needed.

mg10011 (Contributor, Author) commented on May 11, 2025

That makes sense. My issue was that I kept getting an error when HyperNEAT.setup() called NEAT.setup() with a RecurrentGenome. For some reason the state was reinitialized, so as a workaround I monkey-patched NEAT.setup(). Still, I see your reasons for the design choice, and it seems sensible.

My monkey patch:

```python
import jax
import jax.numpy as jnp
from jax import vmap

# Import path as used in the TensorNEAT examples; it may differ between versions.
from tensorneat.algorithm.neat import NEAT


# MONKEY PATCH for the NEAT.setup() method --------------------------------------
def new_neat_setup_method(self_neat, state_input):
    # 1. Preserve the randkey from the input state (which was set by
    # Pipeline.setup). This key is essential for splitting.
    key_to_split_for_population_and_next_state = state_input.randkey

    # 2. Call self.genome.setup().
    # This is the potentially problematic step if it mishandles 'randkey'.
    # interim_state will have genome-specific parameters. Whether 'randkey' is
    # still in interim_state depends on self.genome.setup()'s behavior.
    interim_state = self_neat.genome.setup(state_input)

    # 3. Split the preserved key.
    k1_for_population_init, randkey_for_next_generation = jax.random.split(
        key_to_split_for_population_and_next_state, 2
    )

    # 4. Initialize the population.
    # self.genome.initialize needs the context from interim_state (genome
    # params), which in_axes=(None, 0) broadcasts to every genome.
    per_genome_initialize_keys = jax.random.split(
        k1_for_population_init, self_neat.pop_size
    )
    pop_nodes, pop_conns = vmap(self_neat.genome.initialize, in_axes=(None, 0))(
        interim_state, per_genome_initialize_keys
    )

    # 5. Register new attributes specific to NEAT's population setup.
    # These (pop_nodes, pop_conns, generation) are being introduced here.
    interim_state = interim_state.register(
        pop_nodes=pop_nodes,
        pop_conns=pop_conns,
        generation=jnp.float32(0),
    )

    # 6. Call species_controller.setup().
    # This will further modify the state by registering/updating its own
    # attributes.
    interim_state = self_neat.species_controller.setup(
        interim_state, pop_nodes[0], pop_conns[0]
    )

    # 7. Set the final 'randkey' for the next generation.
    # The 'randkey' attribute was originally present in state_input.
    # - If self.genome.setup() and self.species_controller.setup() correctly
    #   preserved 'randkey' (by only using .register for their *new*
    #   attributes), then 'randkey' will still be in interim_state, and we
    #   should .update() it.
    # - If self.genome.setup() (the buggy part) *removed* 'randkey', then
    #   'randkey' will NOT be in interim_state, and we must .register() it.
    # The original NEAT.setup() uses .update(), implying 'randkey' is expected
    # to exist.
    if "randkey" in interim_state.state_dict:
        final_state = interim_state.update(randkey=randkey_for_next_generation)
    else:
        # This path would be taken if genome.setup (or species_controller.setup)
        # incorrectly removed the randkey.
        final_state = interim_state.register(randkey=randkey_for_next_generation)

    return final_state


NEAT.setup = new_neat_setup_method
```
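Steps 3 and 4 of the patch follow the usual JAX pattern: split one key into a key for population initialization and a key carried to the next generation, split the former into one key per genome, and vectorize the initializer with vmap. The sketch below is standalone; the initialize function, its scale argument, POP_SIZE, and NUM_CONNS are hypothetical placeholders, with scale playing the role that interim_state plays in the patch (broadcast to every genome via in_axes=(None, 0)):

```python
import jax
import jax.numpy as jnp
from jax import vmap

POP_SIZE = 4   # hypothetical population size
NUM_CONNS = 3  # hypothetical number of connection weights per genome


def initialize(scale, key):
    # Hypothetical per-genome initializer: scaled random connection weights.
    return scale * jax.random.normal(key, (NUM_CONNS,))


randkey = jax.random.PRNGKey(0)

# One key for population initialization, one kept for the next generation.
k_init, randkey_next = jax.random.split(randkey, 2)

# One independent key per genome; `scale` is shared (in_axes=None),
# while the keys are mapped over their leading axis (in_axes=0).
per_genome_keys = jax.random.split(k_init, POP_SIZE)
pop_conns = vmap(initialize, in_axes=(None, 0))(jnp.float32(0.5), per_genome_keys)

print(pop_conns.shape)  # (4, 3)
```

Keeping randkey_next separate from k_init means the next generation never reuses randomness already consumed by initialization.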
          Potential BUG: HyperNEAT class · Issue #32 · EMI-Group/tensorneat