BINN

This is the API reference for the BINN package. For usage examples, see Examples. Note that the API is still stabilizing and may change.

BINN

Bases: LightningModule

Implements a Biologically Informed Neural Network (BINN). The BINN is implemented using the Lightning framework, which builds on PyTorch. If you are unfamiliar with PyTorch, we suggest visiting its website: https://pytorch.org/

Parameters:

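| Name | Type | Description | Default |
| --- | --- | --- | --- |
| network | Network | A Network object that defines the network topology. Used when connectivity_matrices is not given. | None |
| connectivity_matrices | list | Precomputed connectivity matrices. If given, they are used instead of deriving them from network. | None |
| activation | str | Activation function to use. | 'tanh' |
| weight | Tensor | Weights for the loss function. In the source listing the tensor is stored as self.weight, and the loss is created as nn.CrossEntropyLoss() without it. | tensor([1, 1]) |
| learning_rate | float | Learning rate for the optimizer. | 0.0001 |
| n_layers | int | Number of layers in the network. | 4 |
| scheduler | str | Learning rate scheduler to use: "plateau" or "step". | 'plateau' |
| optimizer | str | Optimizer to use. The source listing handles "adam" or an already constructed optimizer object. | 'adam' |
| validate | bool | Whether to use validation data during training. | False |
| n_outputs | int | Number of output nodes. | 2 |
| dropout | float | Dropout probability. In the source listing it is passed only to the non-residual (sequential) layers. | 0 |
| residual | bool | Whether to use residual connections. | False |
| device | str | Device the module is moved to on construction. | 'cpu' |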

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| residual | bool | Whether to use residual connections. |
| network | Network | The Network object that defines the network topology. Set only when connectivity_matrices is not passed. |
| n_layers | int | Number of layers in the network. |
| layer_names | List[List[str]] | Node names of each layer, one list per layer. |
| features | Index | A pandas Index object containing the input features. |
| layers | Module | The layers of the BINN. |
| loss | Module | The loss function used during training. |
| learning_rate | float | Learning rate for the optimizer. |
| scheduler | str | Learning rate scheduler used. |
| optimizer | str | Optimizer used. |
| validate | bool | Whether to use validation data during training. |

Source code in binn/binn.py
class BINN(pl.LightningModule):
    """
    Implements a Biologically Informed Neural Network (BINN). The BINN
    is implemented using the Lightning-framework.
    If you are unfamiliar with PyTorch, we suggest visiting
    their website: https://pytorch.org/


    Args:
        network (Network, optional): A Network object that defines the network topology.
            Used when connectivity_matrices is not given.
        connectivity_matrices (list, optional): Precomputed connectivity matrices.
            If given, they are used instead of deriving them from network.
        activation (str, optional): Activation function to use. Defaults to "tanh".
        weight (torch.Tensor, optional): Weights for loss function. Defaults to torch.Tensor([1, 1]).
        learning_rate (float, optional): Learning rate for optimizer. Defaults to 1e-4.
        n_layers (int, optional): Number of layers in the network. Defaults to 4.
        scheduler (str, optional): Learning rate scheduler to use. Defaults to "plateau".
        optimizer (str, optional): Optimizer to use. Defaults to "adam".
        validate (bool, optional): Whether to use validation data during training. Defaults to False.
        n_outputs (int, optional): Number of output nodes. Defaults to 2.
        dropout (float, optional): Dropout probability. Defaults to 0.
        residual (bool, optional): Whether to use residual connections. Defaults to False.
        device (str, optional): Device the module is moved to. Defaults to "cpu".

    Attributes:
        residual (bool): Whether to use residual connections.
        network (Network): A Network object that defines the network topology.
        n_layers (int): Number of layers in the network.
        layer_names (List[List[str]]): Node names of each layer, one list per layer.
        features (Index): A pandas Index object containing the input features.
        layers (nn.Module): The layers of the BINN.
        loss (nn.Module): The loss function used during training.
        learning_rate (float): Learning rate for optimizer.
        scheduler (str): Learning rate scheduler used.
        optimizer (str): Optimizer used.
        validate (bool): Whether to use validation data during training.
    """

    def __init__(
        self,
        network: Network = None,
        connectivity_matrices: list = None,
        activation: str = "tanh",
        weight: torch.tensor = torch.tensor([1, 1]),
        learning_rate: float = 1e-4,
        n_layers: int = 4,
        scheduler: str = "plateau",
        optimizer: str = "adam",
        validate: bool = False,
        n_outputs: int = 2,
        dropout: float = 0,
        residual: bool = False,
        device: str = "cpu",
    ):
        super().__init__()
        self.to(device)
        self.residual = residual
        if not connectivity_matrices:
            self.network = network
            self.connectivity_matrices = self.network.get_connectivity_matrices(
                n_layers
            )
        else:
            self.connectivity_matrices = connectivity_matrices
        self.n_layers = n_layers

        layer_sizes = []
        self.layer_names = []

        matrix = self.connectivity_matrices[0]
        i, _ = matrix.shape
        layer_sizes.append(i)
        self.layer_names.append(matrix.index.tolist())
        self.features = matrix.index
        self.trainable_params = matrix.to_numpy().sum() + len(matrix.index)
        for matrix in self.connectivity_matrices[1:]:
            self.trainable_params += matrix.to_numpy().sum() + len(matrix.index)
            i, _ = matrix.shape
            layer_sizes.append(i)
            self.layer_names.append(matrix.index.tolist())

        if self.residual:
            self.layers = _generate_residual(
                layer_sizes,
                connectivity_matrices=self.connectivity_matrices,
                activation=activation,
                bias=True,
                n_outputs=n_outputs,
            )
        else:
            self.layers = _generate_sequential(
                layer_sizes,
                connectivity_matrices=self.connectivity_matrices,
                activation=activation,
                bias=True,
                n_outputs=n_outputs,
                dropout=dropout,
            )
        self.apply(_init_weights)
        self.weight = weight
        self.loss = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate
        self.scheduler = scheduler
        self.optimizer = optimizer
        self.validate = validate
        self.save_hyperparameters()
        print("\nBINN is on the device:", self.device, end="\n")

    def forward(self, x: torch.tensor) -> torch.tensor:
        """
        Performs a forward pass through the BINN.

        Args:
            x (torch.Tensor): The input tensor to the BINN.

        Returns:
            torch.Tensor: The output tensor of the BINN.
        """
        if self.residual:
            return self._forward_residual(x)
        else:
            return self.layers(x)

    def training_step(self, batch, _):
        """
        Performs a single training step for the BINN.

        Args:
            batch: The batch of data to use for the training step.
            _: Not used.

        Returns:
            torch.Tensor: The loss tensor for the training step.
        """
        x, y = batch
        x = x.to(self.device)
        y = y.to(self.device)
        y_hat = self(x).to(self.device)
        loss = self.loss(y_hat, y)
        prediction = torch.argmax(y_hat, dim=1)
        accuracy = self.calculate_accuracy(y, prediction)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log("train_acc", accuracy, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, _):
        """
        Implements a single validation step for the BINN.

        Args:
            batch: A tuple containing the input and output data for the current batch.
            _: The batch index, which is not used.
        """

        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        prediction = torch.argmax(y_hat, dim=1)
        accuracy = self.calculate_accuracy(y, prediction)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log("val_acc", accuracy, prog_bar=True, on_step=False, on_epoch=True)
        return {"val_loss": loss, "val_acc": accuracy}

    def test_step(self, batch, _):
        """
        Implements a single testing step for the BINN.

        Args:
            batch: A tuple containing the input and output data for the current batch.
            _: The batch index, which is not used.
        """
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        prediction = torch.argmax(y_hat, dim=1)
        accuracy = self.calculate_accuracy(y, prediction)
        self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log("test_acc", accuracy, prog_bar=True, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """
        Configures the optimizer and learning rate scheduler for training the BINN.

        Returns:
            A list of optimizers and a list of learning rate schedulers.
        """
        if self.validate:
            monitor = "val_loss"
        else:
            monitor = "train_loss"

        if isinstance(self.optimizer, str):
            if self.optimizer == "adam":
                optimizer = torch.optim.Adam(
                    self.parameters(), lr=self.learning_rate, weight_decay=1e-3
                )
                self.optimizer = optimizer
        else:
            optimizer = self.optimizer

        if self.scheduler == "plateau":
            scheduler = {
                "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer, patience=5, threshold=0.01, mode="min", verbose=True
                ),
                "interval": "epoch",
                "monitor": monitor,
            }
        elif self.scheduler == "step":
            scheduler = {
                "scheduler": torch.optim.lr_scheduler.StepLR(
                    optimizer, step_size=25, gamma=0.1, verbose=True
                )
            }

        return [optimizer], [scheduler]

    def calculate_accuracy(self, y, prediction) -> float:
        return torch.sum(y == prediction).item() / float(len(y))

    def get_connectivity_matrices(self) -> list:
        """
        Returns the connectivity matrices underlying the BINN.

        Returns:
            The connectivity matrices as a list of Pandas DataFrames.
        """
        return self.connectivity_matrices

    def reset_params(self):
        """
        Resets the trainable parameters of the BINN.
        """
        self.apply(_reset_params)

    def init_weights(self):
        """
        Initializes the trainable parameters of the BINN.
        """
        self.apply(_init_weights)

    def _forward_residual(self, x: torch.tensor):
        x_final = torch.tensor([0, 0], device=self.device)
        residual_counter: int = 0
        for name, layer in self.layers.named_children():
            if name.startswith("Residual"):
                if "out" in name:
                    x_temp = layer(x)
                if _is_activation(layer):
                    x_temp = layer(x_temp)
                    x_final = x_temp + x_final
                    residual_counter = residual_counter + 1
            else:
                x = layer(x)
        x_final = x_final / residual_counter
        return x_final

configure_optimizers()

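Configures the optimizer and learning rate scheduler for training the BINN. In the source below, optimizer may be the string "adam" (Adam with weight_decay=1e-3) or an already constructed optimizer object, and scheduler may be "plateau" (ReduceLROnPlateau) or "step" (StepLR); other string values are not handled. The plateau scheduler monitors val_loss when validate is True and train_loss otherwise.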

Returns:

| Type | Description |
| --- | --- |
| tuple | A list of optimizers and a list of learning rate schedulers. |

Source code in binn/binn.py
def configure_optimizers(self):
    """
    Configures the optimizer and learning rate scheduler for training the BINN.

    Returns:
        A list of optimizers and a list of learning rate schedulers.
    """
    if self.validate:
        monitor = "val_loss"
    else:
        monitor = "train_loss"

    if isinstance(self.optimizer, str):
        if self.optimizer == "adam":
            optimizer = torch.optim.Adam(
                self.parameters(), lr=self.learning_rate, weight_decay=1e-3
            )
            self.optimizer = optimizer
    else:
        optimizer = self.optimizer

    if self.scheduler == "plateau":
        scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, patience=5, threshold=0.01, mode="min", verbose=True
            ),
            "interval": "epoch",
            "monitor": monitor,
        }
    elif self.scheduler == "step":
        scheduler = {
            "scheduler": torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=25, gamma=0.1, verbose=True
            )
        }

    return [optimizer], [scheduler]
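To make the two scheduler settings above more concrete, here is a plain-Python sketch of how the learning rate would evolve. It is a simplified model rather than the PyTorch implementation: `step_lr`, `PlateauSketch` and `pick_monitor` are hypothetical names, the plateau reduction factor of 0.1 is assumed from PyTorch's default (the listing does not set it), and cooldown, min_lr and eps handling are left out.

```python
def pick_monitor(validate):
    """Metric the plateau scheduler watches, as chosen in the listing."""
    return "val_loss" if validate else "train_loss"


def step_lr(initial_lr, epoch, step_size=25, gamma=0.1):
    """Learning rate after `epoch` epochs under a StepLR-style schedule."""
    return initial_lr * gamma ** (epoch // step_size)


class PlateauSketch:
    """Simplified ReduceLROnPlateau with mode="min" and a relative threshold."""

    def __init__(self, lr, patience=5, threshold=0.01, factor=0.1):
        self.lr = lr
        self.patience = patience
        self.threshold = threshold
        self.factor = factor  # assumed default, not set in the listing
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, monitored_loss):
        # An epoch counts as an improvement only if the loss drops by more
        # than `threshold` relative to the best value seen so far.
        if monitored_loss < self.best * (1 - self.threshold):
            self.best = monitored_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # Reduce once the number of bad epochs exceeds the patience.
        if self.bad_epochs > self.patience:
            self.lr *= self.factor
            self.bad_epochs = 0
        return self.lr


plateau = PlateauSketch(lr=1e-4)
# One good epoch followed by six epochs without a meaningful improvement.
history = [plateau.step(loss) for loss in [1.0] + [0.999] * 6]
print(history[-2], history[-1])
print(step_lr(1e-4, 24), step_lr(1e-4, 25))
```

With patience=5, the rate is cut on the sixth consecutive epoch without improvement, while the step schedule cuts it every 25 epochs regardless of the loss.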

forward(x)

Performs a forward pass through the BINN.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | The input tensor to the BINN. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The output tensor of the BINN. |

Source code in binn/binn.py
def forward(self, x: torch.tensor) -> torch.tensor:
    """
    Performs a forward pass through the BINN.

    Args:
        x (torch.Tensor): The input tensor to the BINN.

    Returns:
        torch.Tensor: The output tensor of the BINN.
    """
    if self.residual:
        return self._forward_residual(x)
    else:
        return self.layers(x)
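When residual is True, forward delegates to `_forward_residual` (shown in the class listing), which sums the activated outputs of the residual heads and divides by their count. The sketch below illustrates that averaging in plain Python. It assumes, for illustration only, that each hidden block is followed by exactly one output head; the real layer order comes from `_generate_residual`, which is not shown here, and `forward_residual` is a hypothetical stand-in.

```python
def forward_residual(blocks, x):
    """Average the head outputs over all blocks.

    blocks: list of (hidden_fn, head_fn) pairs, where hidden_fn maps the
    running activation to the next layer and head_fn maps it to class scores.
    """
    total = None
    for hidden_fn, head_fn in blocks:
        x = hidden_fn(x)       # main path: x = layer(x)
        head_out = head_fn(x)  # residual head output for this depth
        if total is None:
            total = list(head_out)
        else:
            total = [t + h for t, h in zip(total, head_out)]
    return [t / len(blocks) for t in total]


# Toy blocks with hand-written functions instead of trained layers.
blocks = [
    (lambda v: [2 * a for a in v], lambda v: [sum(v), -sum(v)]),
    (lambda v: [a + 1 for a in v], lambda v: [sum(v), 0.0]),
]
scores = forward_residual(blocks, [1.0, 2.0])
print(scores)
```

Averaging the heads lets every depth of the network contribute to the prediction, rather than only the final layer.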

get_connectivity_matrices()

Returns the connectivity matrices underlying the BINN.

Returns:

| Type | Description |
| --- | --- |
| list | The connectivity matrices as a list of Pandas DataFrames. |

Source code in binn/binn.py
def get_connectivity_matrices(self) -> list:
    """
    Returns the connectivity matrices underlying the BINN.

    Returns:
        The connectivity matrices as a list of Pandas DataFrames.
    """
    return self.connectivity_matrices
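How the connectivity matrices constrain the layers is decided inside `_generate_sequential` and `_generate_residual`, which this page does not show. A common way to realize such a constraint, and the assumption behind the sketch below, is to treat each matrix as a 0/1 mask on a dense layer's weights, with rows as input nodes and columns as output nodes. `masked_linear` is a hypothetical illustration, not the package's implementation.

```python
def masked_linear(x, weights, bias, connectivity):
    """y[j] = bias[j] + sum_i x[i] * weights[i][j] * connectivity[i][j]."""
    n_in = len(connectivity)
    n_out = len(connectivity[0])
    return [
        bias[j]
        + sum(x[i] * weights[i][j] * connectivity[i][j] for i in range(n_in))
        for j in range(n_out)
    ]


# Illustrative values (made up): 3 input nodes feeding 2 output nodes.
connectivity = [[1, 0], [1, 1], [0, 1]]
weights = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
bias = [0.0, 0.0]
y = masked_linear([1.0, 2.0, 3.0], weights, bias, connectivity)
print(y)
```

Under this assumption, an input only influences the nodes it is annotated to, which is what makes the learned weights interpretable in terms of the underlying pathways.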

init_weights()

Initializes the trainable parameters of the BINN.

Source code in binn/binn.py
def init_weights(self):
    """
    Initializes the trainable parameters of the BINN.
    """
    self.apply(_init_weights)
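Both init_weights and reset_params rely on `apply`, which in PyTorch calls the given function on every submodule and then on the module itself. The plain-Python sketch below shows that traversal order; `Node` is a hypothetical stand-in for a module, not `torch.nn.Module`.

```python
class Node:
    """Minimal stand-in for a module that owns child modules."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)

    def apply(self, fn):
        # Visit children first (recursively), then the node itself.
        for child in self.children:
            child.apply(fn)
        fn(self)
        return self


visited = []
model = Node("binn", [Node("layers", [Node("linear_0"), Node("tanh_0")])])
model.apply(lambda node: visited.append(node.name))
print(visited)
```

Because the function is called on every node, helpers such as `_init_weights` typically check the type of the module they receive before touching its parameters.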

reset_params()

Resets the trainable parameters of the BINN.

Source code in binn/binn.py
def reset_params(self):
    """
    Resets the trainable parameters of the BINN.
    """
    self.apply(_reset_params)

test_step(batch, _)

Implements a single testing step for the BINN.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | | A tuple containing the input and output data for the current batch. | required |
| _ | | The batch index, which is not used. | required |

Source code in binn/binn.py
def test_step(self, batch, _):
    """
    Implements a single testing step for the BINN.

    Args:
        batch: A tuple containing the input and output data for the current batch.
        _: The batch index, which is not used.
    """
    x, y = batch
    y_hat = self(x)
    loss = self.loss(y_hat, y)
    prediction = torch.argmax(y_hat, dim=1)
    accuracy = self.calculate_accuracy(y, prediction)
    self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
    self.log("test_acc", accuracy, prog_bar=True, on_step=False, on_epoch=True)
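The loss in each step comes from `nn.CrossEntropyLoss`, which takes raw scores (logits) and integer class labels and, with its default settings, returns the mean negative log-probability of the true class. A plain-Python sketch of that computation follows; `cross_entropy` is a hypothetical helper and covers only the unweighted, mean-reduced case.

```python
import math


def cross_entropy(logits, targets):
    """Mean of -log softmax(logits)[target] over the batch."""
    losses = []
    for row, target in zip(logits, targets):
        shift = max(row)  # subtract the max for numerical stability
        log_norm = shift + math.log(sum(math.exp(v - shift) for v in row))
        losses.append(log_norm - row[target])
    return sum(losses) / len(losses)


# Equal scores for two classes give a loss of log(2).
uninformed = cross_entropy([[0.0, 0.0]], [0])
# A confident, correct prediction gives a loss close to zero.
confident = cross_entropy([[10.0, 0.0]], [0])
print(uninformed, confident)
```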

training_step(batch, _)

Performs a single training step for the BINN.

Parameters:

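| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | | A tuple containing the input and output data for the current batch. | required |
| _ | | The batch index, which is not used. | required |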

Returns:

| Type | Description |
| --- | --- |
| Tensor | The loss tensor for the training step. |

Source code in binn/binn.py
def training_step(self, batch, _):
    """
    Performs a single training step for the BINN.

    Args:
        batch: The batch of data to use for the training step.
        _: Not used.

    Returns:
        torch.Tensor: The loss tensor for the training step.
    """
    x, y = batch
    x = x.to(self.device)
    y = y.to(self.device)
    y_hat = self(x).to(self.device)
    loss = self.loss(y_hat, y)
    prediction = torch.argmax(y_hat, dim=1)
    accuracy = self.calculate_accuracy(y, prediction)
    self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
    self.log("train_acc", accuracy, prog_bar=True, on_step=False, on_epoch=True)
    return loss
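The accuracy logged in each step is the fraction of samples whose highest-scoring class equals the label. The sketch below reproduces the argmax and accuracy steps in plain Python, with lists standing in for tensors; `argmax_rows` and `accuracy` are hypothetical helper names.

```python
def argmax_rows(logits):
    """Index of the largest score in each row (like torch.argmax(..., dim=1))."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def accuracy(targets, predictions):
    """Fraction of positions where the prediction equals the target."""
    correct = sum(t == p for t, p in zip(targets, predictions))
    return correct / float(len(targets))


logits = [[0.2, 0.8], [0.9, 0.1], [0.4, 0.6]]
targets = [1, 0, 0]
predictions = argmax_rows(logits)
print(predictions, accuracy(targets, predictions))
```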

validation_step(batch, _)

Implements a single validation step for the BINN. Logs val_loss and val_acc per epoch and returns both in a dict.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | | A tuple containing the input and output data for the current batch. | required |
| _ | | The batch index, which is not used. | required |

Source code in binn/binn.py
def validation_step(self, batch, _):
    """
    Implements a single validation step for the BINN.

    Args:
        batch: A tuple containing the input and output data for the current batch.
        _: The batch index, which is not used.
    """

    x, y = batch
    y_hat = self(x)
    loss = self.loss(y_hat, y)
    prediction = torch.argmax(y_hat, dim=1)
    accuracy = self.calculate_accuracy(y, prediction)
    self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
    self.log("val_acc", accuracy, prog_bar=True, on_step=False, on_epoch=True)
    return {"val_loss": loss, "val_acc": accuracy}