Creating a BINN¶
This notebook walks through a few examples of how a BINN can be created and trained.
The method begins by constructing a directed graph of biological pathways and mapping the input features (e.g., proteins or genes) to nodes in that graph. The graph is then manipulated to produce hierarchical layers and connectivity matrices, which determine the structure of the BINN.
If you want to create your own BINN from scratch, you need some input data (input_data below) in the form of a pandas DataFrame.
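The data matrix has one identifier column for the input features (here a Protein column of UniProt accessions, which is how it is used throughout this notebook) and one column per sample. A minimal sketch with made-up values; the sample column names and numbers are hypothetical:

import pandas as pd

# Hypothetical toy data matrix: one row per protein, one column per sample.
# Only the "Protein" column name is taken from this notebook; the sample
# columns and values are invented for illustration.
toy_matrix = pd.DataFrame(
    {
        "Protein": ["A0M8Q6", "P02768", "P01834"],
        "sample_1": [1.2, 3.4, 0.7],
        "sample_2": [1.1, 3.0, 0.9],
    }
)
print(toy_matrix)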
from binn import BINN
import pandas as pd

input_data = pd.read_csv("../binn/data/sample_datamatrix.csv")

binn = BINN(
    data_matrix=input_data,
    network_source="reactome",
    input_source="uniprot",
    n_layers=4,
    dropout=0.2,
)
binn
[INFO] BINN is on device: cpu
BINN(
  (layers): Sequential(
    (Layer_0): Linear(in_features=448, out_features=471, bias=True)
    (BatchNorm_0): BatchNorm1d(471, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (Dropout_0): Dropout(p=0.2, inplace=False)
    (Tanh_0): Tanh()
    (Layer_1): Linear(in_features=471, out_features=306, bias=True)
    (BatchNorm_1): BatchNorm1d(306, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (Dropout_1): Dropout(p=0.2, inplace=False)
    (Tanh_1): Tanh()
    (Layer_2): Linear(in_features=306, out_features=125, bias=True)
    (BatchNorm_2): BatchNorm1d(125, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (Dropout_2): Dropout(p=0.2, inplace=False)
    (Tanh_2): Tanh()
    (Layer_3): Linear(in_features=125, out_features=28, bias=True)
    (BatchNorm_3): BatchNorm1d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (Dropout_3): Dropout(p=0.2, inplace=False)
    (Tanh_3): Tanh()
    (Output): Linear(in_features=28, out_features=2, bias=True)
  )
)
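Because the connections are pruned to follow the pathway graph, many of the weights in each Linear layer fall outside the allowed pathway connections. A quick sanity-check sketch using only standard PyTorch, assuming binn behaves like an ordinary nn.Module (as the repr above suggests); depending on how the implementation applies its pruning masks, zeros may only show up in the weights after masking or training:

import torch

# Fraction of nonzero weights in the first pruned layer (Layer_0 above).
# binn.layers is the nn.Sequential printed above, so integer indexing works.
first_linear = binn.layers[0]
nonzero = (first_linear.weight != 0).float().mean().item()
print(f"Nonzero weights in Layer_0: {nonzero:.2%}")

# Total trainable parameter count for the whole model.
print(sum(p.numel() for p in binn.parameters() if p.requires_grad), "trainable parameters")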
You can also provide your own pathways and mapping to create a PathwayNetwork. The PathwayNetwork is what underlies the pruning that gives the BINN its sparsity. The pathways file is a standard edge list, and the mapping file maps the input features to nodes in that edge list.
from binn import PathwayNetwork

mapping = pd.read_csv("../binn/data/downloads/uniprot_2_reactome_2025_01_14.txt", sep="\t")
pathways = pd.read_csv("../binn/data/downloads/reactome_pathways_relation_2025_01_14.txt", sep="\t")

pathways = list(pathways.itertuples(index=False, name=None))
mapping = list(mapping.itertuples(index=False, name=None))

input_entities = input_data["Protein"].tolist()

network = PathwayNetwork(
    input_data=input_entities,
    pathways=pathways,
    mapping=mapping,
)
list(network.pathway_graph.edges())[0]
('R-HSA-109703', 'R-HSA-109704')
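The .edges() call above suggests that pathway_graph is a networkx-style directed graph. Assuming that, a short sketch for inspecting it before it is collapsed into layers:

# Sketch assuming network.pathway_graph is a networkx.DiGraph,
# as the .edges() call above suggests.
g = network.pathway_graph
print(g.number_of_nodes(), "nodes,", g.number_of_edges(), "edges")

# Nodes with no incoming edges; whether these are root pathways or input
# entities depends on the direction in which the edges are stored.
sources = [n for n in g.nodes() if g.in_degree(n) == 0]
print(sources[:5])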
# or custom pathways and mapping
mapping = pd.read_csv(
    "../binn/data/downloads/uniprot_2_reactome_2025_01_14.txt",
    sep="\t",
    header=None,
    names=["input", "translation", "url", "name", "x", "species"],
)
pathways = pd.read_csv(
    "../binn/data/downloads/reactome_pathways_relation_2025_01_14.txt",
    sep="\t",
    header=None,
    names=["target", "source"],
)
binn = BINN(data_matrix=input_data, mapping=mapping, pathways=pathways)
binn.layers
[INFO] BINN is on device: cpu
Sequential(
  (Layer_0): Linear(in_features=448, out_features=471, bias=True)
  (BatchNorm_0): BatchNorm1d(471, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (Dropout_0): Dropout(p=0, inplace=False)
  (Tanh_0): Tanh()
  (Layer_1): Linear(in_features=471, out_features=306, bias=True)
  (BatchNorm_1): BatchNorm1d(306, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (Dropout_1): Dropout(p=0, inplace=False)
  (Tanh_1): Tanh()
  (Layer_2): Linear(in_features=306, out_features=125, bias=True)
  (BatchNorm_2): BatchNorm1d(125, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (Dropout_2): Dropout(p=0, inplace=False)
  (Tanh_2): Tanh()
  (Layer_3): Linear(in_features=125, out_features=28, bias=True)
  (BatchNorm_3): BatchNorm1d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (Dropout_3): Dropout(p=0, inplace=False)
  (Tanh_3): Tanh()
  (Output): Linear(in_features=28, out_features=2, bias=True)
)
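The layer widths (448 → 471 → 306 → 125 → 28 → 2) come from the number of pathway nodes at each level of the hierarchy. Since binn.layers is a plain nn.Sequential, the widths can be read off directly; a small sketch:

import torch.nn as nn

# Out-features of each linear layer in the Sequential shown above;
# these correspond to the pathway counts per hierarchy level,
# plus the final output layer.
widths = [m.out_features for m in binn.layers if isinstance(m, nn.Linear)]
print(widths)  # [471, 306, 125, 28, 2]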
We can also build an ensemble of heads, in which the output of each layer in the network is passed through its own linear head, and the head outputs are summed at the end.
binn = BINN(
    data_matrix=input_data,
    network_source="reactome",
    heads_ensemble=True,
    n_layers=4,
    dropout=0.2,
)
binn.layers
[INFO] BINN is on device: cpu
_EnsembleHeads(
  (blocks): ModuleList(
    (0): Sequential(
      (Linear_0): Linear(in_features=448, out_features=471, bias=True)
      (BatchNorm_0): BatchNorm1d(471, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (Tanh_0): Tanh()
    )
    (1): Sequential(
      (Linear_1): Linear(in_features=471, out_features=306, bias=True)
      (BatchNorm_1): BatchNorm1d(306, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (Tanh_1): Tanh()
    )
    (2): Sequential(
      (Linear_2): Linear(in_features=306, out_features=125, bias=True)
      (BatchNorm_2): BatchNorm1d(125, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (Tanh_2): Tanh()
    )
    (3): Sequential(
      (Linear_3): Linear(in_features=125, out_features=28, bias=True)
      (BatchNorm_3): BatchNorm1d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (Tanh_3): Tanh()
    )
  )
  (heads): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=471, out_features=2, bias=True)
      (1): Sigmoid()
    )
    (1): Sequential(
      (0): Linear(in_features=306, out_features=2, bias=True)
      (1): Sigmoid()
    )
    (2): Sequential(
      (0): Linear(in_features=125, out_features=2, bias=True)
      (1): Sigmoid()
    )
    (3): Sequential(
      (0): Linear(in_features=28, out_features=2, bias=True)
      (1): Sigmoid()
    )
  )
)
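The model also stores its input feature names; the first one is a UniProt accession: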
binn.inputs[0]
np.str_('A0M8Q6')
Looking at the layer names, we see that they correspond to the entities in the input and intermediate layers of the model.
layers = binn.layer_names
layers[0][0]
np.str_('A0M8Q6')
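Assuming layer_names is a list of array-likes, one per layer (as the indexing above suggests), the number of names per layer should line up with the layer widths; a short sketch:

# Sketch: print the number of named entities per layer — assumes
# binn.layer_names is a list of array-likes, as the indexing above suggests.
for i, names in enumerate(binn.layer_names):
    print(f"Layer {i}: {len(names)} entities, e.g. {names[0]}")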
Training¶
from binn import BINN, BINNDataLoader, BINNTrainer
import pandas as pd

# Load your data
data_matrix = pd.read_csv("../binn/data/sample_datamatrix.csv")
design_matrix = pd.read_csv("../binn/data/sample_design_matrix.tsv", sep="\t")

# Initialize BINN
binn = BINN(data_matrix=data_matrix, network_source="reactome", n_layers=4, dropout=0.2)

# Initialize DataLoader
binn_dataloader = BINNDataLoader(binn)

# Create DataLoaders
dataloaders = binn_dataloader.create_dataloaders(
    data_matrix=data_matrix,
    design_matrix=design_matrix,
    feature_column="Protein",
    group_column="group",
    sample_column="sample",
    batch_size=32,
    validation_split=0.2,
)

# Train the model
trainer = BINNTrainer(binn)
trainer.fit(dataloaders=dataloaders, num_epochs=50)
[INFO] BINN is on device: cpu
Mapping group labels: {np.int64(1): 0, np.int64(2): 1}
[Epoch 1/50] Train Loss: 0.6418, Train Accuracy: 0.6054
[Epoch 1/50] Val Loss: 0.6931, Val Accuracy: 0.5312
[Epoch 2/50] Train Loss: 0.6577, Train Accuracy: 0.6573
[Epoch 2/50] Val Loss: 0.6929, Val Accuracy: 0.5312
[Epoch 3/50] Train Loss: 0.6903, Train Accuracy: 0.6360
[Epoch 3/50] Val Loss: 0.6925, Val Accuracy: 0.5312
[Epoch 4/50] Train Loss: 0.6185, Train Accuracy: 0.6705
[Epoch 4/50] Val Loss: 0.6921, Val Accuracy: 0.5312
[Epoch 5/50] Train Loss: 0.6848, Train Accuracy: 0.5985
[Epoch 5/50] Val Loss: 0.6916, Val Accuracy: 0.5312
[Epoch 6/50] Train Loss: 0.6150, Train Accuracy: 0.6623
[Epoch 6/50] Val Loss: 0.6906, Val Accuracy: 0.6719
[Epoch 7/50] Train Loss: 0.6657, Train Accuracy: 0.5985
[Epoch 7/50] Val Loss: 0.6872, Val Accuracy: 0.6875
[Epoch 8/50] Train Loss: 0.6146, Train Accuracy: 0.6616
[Epoch 8/50] Val Loss: 0.6794, Val Accuracy: 0.6562
[Epoch 9/50] Train Loss: 0.5771, Train Accuracy: 0.6866
[Epoch 9/50] Val Loss: 0.6634, Val Accuracy: 0.6406
[Epoch 10/50] Train Loss: 0.6273, Train Accuracy: 0.6261
[Epoch 10/50] Val Loss: 0.6365, Val Accuracy: 0.6250
[Epoch 11/50] Train Loss: 0.6453, Train Accuracy: 0.6491
[Epoch 11/50] Val Loss: 0.6013, Val Accuracy: 0.6094
[Epoch 12/50] Train Loss: 0.5779, Train Accuracy: 0.7254
[Epoch 12/50] Val Loss: 0.5644, Val Accuracy: 0.6094
[Epoch 13/50] Train Loss: 0.5509, Train Accuracy: 0.6886
[Epoch 13/50] Val Loss: 0.5365, Val Accuracy: 0.6250
[Epoch 14/50] Train Loss: 0.5646, Train Accuracy: 0.7399
[Epoch 14/50] Val Loss: 0.5061, Val Accuracy: 0.6406
[Epoch 15/50] Train Loss: 0.6014, Train Accuracy: 0.6899
[Epoch 15/50] Val Loss: 0.4778, Val Accuracy: 0.6562
[Epoch 16/50] Train Loss: 0.5455, Train Accuracy: 0.7267
[Epoch 16/50] Val Loss: 0.4539, Val Accuracy: 0.6719
[Epoch 17/50] Train Loss: 0.5388, Train Accuracy: 0.7067
[Epoch 17/50] Val Loss: 0.4343, Val Accuracy: 0.7656
[Epoch 18/50] Train Loss: 0.5539, Train Accuracy: 0.7185
[Epoch 18/50] Val Loss: 0.4200, Val Accuracy: 0.8281
[Epoch 19/50] Train Loss: 0.5271, Train Accuracy: 0.7060
[Epoch 19/50] Val Loss: 0.4084, Val Accuracy: 0.8281
[Epoch 20/50] Train Loss: 0.5402, Train Accuracy: 0.7274
[Epoch 20/50] Val Loss: 0.3927, Val Accuracy: 0.8594
[Epoch 21/50] Train Loss: 0.4913, Train Accuracy: 0.7780
[Epoch 21/50] Val Loss: 0.3846, Val Accuracy: 0.8438
[Epoch 22/50] Train Loss: 0.4765, Train Accuracy: 0.7987
[Epoch 22/50] Val Loss: 0.3799, Val Accuracy: 0.8281
[Epoch 23/50] Train Loss: 0.5221, Train Accuracy: 0.7330
[Epoch 23/50] Val Loss: 0.3813, Val Accuracy: 0.8438
[Epoch 24/50] Train Loss: 0.5363, Train Accuracy: 0.7110
[Epoch 24/50] Val Loss: 0.3797, Val Accuracy: 0.8438
[Epoch 25/50] Train Loss: 0.4313, Train Accuracy: 0.8274
[Epoch 25/50] Val Loss: 0.3781, Val Accuracy: 0.8438
[Epoch 26/50] Train Loss: 0.4942, Train Accuracy: 0.7593
[Epoch 26/50] Val Loss: 0.3714, Val Accuracy: 0.8281
[Epoch 27/50] Train Loss: 0.4628, Train Accuracy: 0.8030
[Epoch 27/50] Val Loss: 0.3701, Val Accuracy: 0.8438
[Epoch 28/50] Train Loss: 0.5169, Train Accuracy: 0.7692
[Epoch 28/50] Val Loss: 0.3647, Val Accuracy: 0.8438
[Epoch 29/50] Train Loss: 0.4432, Train Accuracy: 0.7817
[Epoch 29/50] Val Loss: 0.3614, Val Accuracy: 0.8281
[Epoch 30/50] Train Loss: 0.4575, Train Accuracy: 0.7787
[Epoch 30/50] Val Loss: 0.3685, Val Accuracy: 0.8906
[Epoch 31/50] Train Loss: 0.4333, Train Accuracy: 0.8099
[Epoch 31/50] Val Loss: 0.3616, Val Accuracy: 0.8906
[Epoch 32/50] Train Loss: 0.4628, Train Accuracy: 0.7823
[Epoch 32/50] Val Loss: 0.3494, Val Accuracy: 0.9062
[Epoch 33/50] Train Loss: 0.4940, Train Accuracy: 0.7567
[Epoch 33/50] Val Loss: 0.3428, Val Accuracy: 0.8750
[Epoch 34/50] Train Loss: 0.4694, Train Accuracy: 0.8168
[Epoch 34/50] Val Loss: 0.3310, Val Accuracy: 0.9375
[Epoch 35/50] Train Loss: 0.4325, Train Accuracy: 0.7767
[Epoch 35/50] Val Loss: 0.3372, Val Accuracy: 0.9062
[Epoch 36/50] Train Loss: 0.4716, Train Accuracy: 0.7886
[Epoch 36/50] Val Loss: 0.3479, Val Accuracy: 0.9062
[Epoch 37/50] Train Loss: 0.3970, Train Accuracy: 0.8586
[Epoch 37/50] Val Loss: 0.3442, Val Accuracy: 0.9062
[Epoch 38/50] Train Loss: 0.4235, Train Accuracy: 0.8343
[Epoch 38/50] Val Loss: 0.3405, Val Accuracy: 0.9062
[Epoch 39/50] Train Loss: 0.4820, Train Accuracy: 0.8037
[Epoch 39/50] Val Loss: 0.3304, Val Accuracy: 0.9062
[Epoch 40/50] Train Loss: 0.4097, Train Accuracy: 0.8287
[Epoch 40/50] Val Loss: 0.3296, Val Accuracy: 0.9062
[Epoch 41/50] Train Loss: 0.3964, Train Accuracy: 0.8231
[Epoch 41/50] Val Loss: 0.3088, Val Accuracy: 0.9219
[Epoch 42/50] Train Loss: 0.3760, Train Accuracy: 0.8405
[Epoch 42/50] Val Loss: 0.3109, Val Accuracy: 0.9062
[Epoch 43/50] Train Loss: 0.3940, Train Accuracy: 0.8254
[Epoch 43/50] Val Loss: 0.3070, Val Accuracy: 0.9062
[Epoch 44/50] Train Loss: 0.3803, Train Accuracy: 0.8530
[Epoch 44/50] Val Loss: 0.3104, Val Accuracy: 0.9062
[Epoch 45/50] Train Loss: 0.3934, Train Accuracy: 0.8274
[Epoch 45/50] Val Loss: 0.3083, Val Accuracy: 0.9062
[Epoch 46/50] Train Loss: 0.4165, Train Accuracy: 0.8017
[Epoch 46/50] Val Loss: 0.2965, Val Accuracy: 0.9219
[Epoch 47/50] Train Loss: 0.4264, Train Accuracy: 0.8011
[Epoch 47/50] Val Loss: 0.2967, Val Accuracy: 0.9219
[Epoch 48/50] Train Loss: 0.3473, Train Accuracy: 0.8418
[Epoch 48/50] Val Loss: 0.2932, Val Accuracy: 0.9219
[Epoch 49/50] Train Loss: 0.4035, Train Accuracy: 0.8149
[Epoch 49/50] Val Loss: 0.2805, Val Accuracy: 0.9219
[Epoch 50/50] Train Loss: 0.4195, Train Accuracy: 0.7892
[Epoch 50/50] Val Loss: 0.2777, Val Accuracy: 0.9219
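After training, inference works with plain PyTorch. A minimal evaluation sketch, assuming the dictionary returned by create_dataloaders has a "val" loader yielding (features, labels) batches (the trainer's Val Loss lines suggest such a loader exists); adjust the key and unpacking if the actual return format differs:

import torch

# Hedged evaluation sketch using only standard PyTorch: accuracy on
# the validation loader, assuming dataloaders["val"] exists and yields
# (features, labels) batches.
binn.eval()
correct = total = 0
with torch.no_grad():
    for features, labels in dataloaders["val"]:
        preds = binn(features).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.numel()
print(f"Validation accuracy: {correct / total:.3f}")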