Robustness
In this example, we analyze how robust feature selection with BINN is. Feature selection methods are generally data-hungry: many samples are needed to obtain robust estimates of feature importance. One way to assess robustness is to train and explain several models on the same data and compare their feature importances.
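The data-hunger argument can be illustrated with a toy experiment that does not involve BINN at all. In this sketch we assume, purely for illustration, that a feature's importance is the coefficient of an ordinary least-squares fit; the helper coefficient_spread and the synthetic weights are hypothetical. Refitting on fresh datasets of different sizes shows how much the importance estimates fluctuate.

```python
import numpy as np

def coefficient_spread(n_samples: int, n_repeats: int = 200, seed: int = 0) -> float:
    """Average std of least-squares coefficients across repeated fits.

    Hypothetical toy stand-in for 'feature importance': each repeat draws a
    fresh dataset of n_samples rows from the same linear model and refits it.
    """
    rng = np.random.default_rng(seed)
    true_w = np.array([2.0, -1.0, 0.5])  # hypothetical ground-truth weights
    estimates = []
    for _ in range(n_repeats):
        X = rng.normal(size=(n_samples, true_w.size))
        y = X @ true_w + rng.normal(scale=1.0, size=n_samples)  # additive noise
        w_hat, *_ = np.linalg.lstsq(X, y, rcond=None)
        estimates.append(w_hat)
    # Spread of each coefficient over the repeats, averaged over the features
    return float(np.std(np.array(estimates), axis=0).mean())

small = coefficient_spread(n_samples=20)
large = coefficient_spread(n_samples=2000)
print(f"spread with 20 samples: {small:.3f}, with 2000 samples: {large:.3f}")
```

The spread with 20 samples should be several times larger than with 2,000 samples, in line with the familiar 1/sqrt(n) shrinkage of estimation error. A neural model such as BINN adds further variability from random initialization and dropout, which is what the repeated training below probes.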
In [1]:
from binn import Network, BINN
import pandas as pd
import numpy as np
import torch
import random
# random seed for reproducibility
random.seed(42)
torch.manual_seed(42)
np.random.seed(42)
input_data = pd.read_csv("../data/test_qm.csv")
translation = pd.read_csv("../data/translation.tsv", sep="\t")
pathways = pd.read_csv("../data/pathways.tsv", sep="\t")
network = Network(
    input_data=input_data,
    pathways=pathways,
    mapping=translation,
    source_column="child",
    target_column="parent",
)
binn = BINN(
    network=network,
    n_layers=4,
    dropout=0.2,
    validate=False,
    residual=False,
    device="cpu",
    learning_rate=0.001,
)
IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
BINN is on the device: cpu
In [2]:
from util_for_examples import fit_data_matrix_to_network_input, generate_data
import torch
design_matrix = pd.read_csv("../data/design_matrix.tsv", sep="\t")
protein_matrix = fit_data_matrix_to_network_input(input_data, features=network.inputs)
X, y = generate_data(protein_matrix, design_matrix=design_matrix)
dataset = torch.utils.data.TensorDataset(
torch.tensor(X, dtype=torch.float32, device=binn.device),
torch.tensor(y, dtype=torch.int16, device=binn.device),
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
In [3]:
from binn.explainer import BINNExplainer
explainer = BINNExplainer(binn)
Explain average
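With explainer.explain_average(), we train n_iterations models and explain each of them separately. The resulting feature importances are collected in the returned importance_df: one column per iteration (value_0, value_1, ...), followed by their mean (value_mean), their standard deviation (values_std) and value.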
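The summary columns can be reproduced from the per-iteration columns. The sketch below is an assumption about how they are derived, not BINN's actual code: the numbers in the output below are consistent with value_mean being the row-wise mean and values_std the population standard deviation (ddof=0) of value_0 to value_9, with value coinciding with value_mean. The helper summarize_iterations is hypothetical.

```python
import pandas as pd

def summarize_iterations(df: pd.DataFrame, n_iterations: int) -> pd.DataFrame:
    """Hypothetical helper: add mean/std columns over the per-iteration importances."""
    value_cols = [f"value_{i}" for i in range(n_iterations)]
    out = df.copy()
    out["value_mean"] = out[value_cols].mean(axis=1)
    # ddof=0 (population std); pandas' default would be ddof=1
    out["values_std"] = out[value_cols].std(axis=1, ddof=0)
    out["value"] = out["value_mean"]
    return out

# Per-iteration importances of row 0 in the Out[4] table
row0 = [0.003766, 0.04759, 0.040599, 0.037591, 0.030930,
        0.050208, 0.044990, 0.020216, 0.037221, 0.067721]
example = pd.DataFrame([row0], columns=[f"value_{i}" for i in range(10)])
summary = summarize_iterations(example, n_iterations=10)
print(summary[["value_mean", "values_std", "value"]])
```

For row 0 this gives a mean of about 0.038083 and a standard deviation of about 0.016489, matching the table. Using ddof=1 instead would give roughly 0.0174, so the choice matters when comparing against the library output.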
In [4]:
test_data = torch.Tensor(X)
background_data = torch.Tensor(X)
n_iterations = 10
n_epochs = 30
importance_df = explainer.explain_average(test_data, background_data, n_iterations, n_epochs, dataloader, fast_train=True)
importance_df.head()
Iteration 0 Final epoch: Average Accuracy 0.96, Average Loss: 0.14
Iteration 1 Final epoch: Average Accuracy 0.94, Average Loss: 0.15
Iteration 2 Final epoch: Average Accuracy 0.96, Average Loss: 0.12
Iteration 3 Final epoch: Average Accuracy 0.96, Average Loss: 0.14
Iteration 4 Final epoch: Average Accuracy 0.95, Average Loss: 0.15
Iteration 5 Final epoch: Average Accuracy 0.92, Average Loss: 0.19
Iteration 6 Final epoch: Average Accuracy 0.94, Average Loss: 0.14
Iteration 7 Final epoch: Average Accuracy 0.95, Average Loss: 0.15
Iteration 8 Final epoch: Average Accuracy 0.95, Average Loss: 0.12
Iteration 9 Final epoch: Average Accuracy 0.96, Average Loss: 0.15
Out[4]:
| | source | target | source name | target name | type | source layer | target layer | value_0 | value_1 | value_2 | value_3 | value_4 | value_5 | value_6 | value_7 | value_8 | value_9 | value_mean | values_std | value |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 497 | A0M8Q6 | R-HSA-166663 | 0 | 0 | 1 | 0.003766 | 0.04759 | 0.040599 | 0.037591 | 0.030930 | 0.050208 | 0.044990 | 0.020216 | 0.037221 | 0.067721 | 0.038083 | 0.016489 | 0.038083 |
| 1 | 1 | 497 | A0M8Q6 | R-HSA-166663 | 1 | 0 | 1 | 0.020309 | 0.04680 | 0.014003 | 0.018245 | 0.056983 | 0.062000 | 0.052162 | 0.033551 | 0.087089 | 0.098307 | 0.048945 | 0.027133 | 0.048945 |
| 2 | 1 | 954 | A0M8Q6 | R-HSA-198933 | 0 | 0 | 1 | 0.003766 | 0.04759 | 0.040599 | 0.037591 | 0.030930 | 0.050208 | 0.044990 | 0.020216 | 0.037221 | 0.067721 | 0.038083 | 0.016489 | 0.038083 |
| 3 | 1 | 954 | A0M8Q6 | R-HSA-198933 | 1 | 0 | 1 | 0.020309 | 0.04680 | 0.014003 | 0.018245 | 0.056983 | 0.062000 | 0.052162 | 0.033551 | 0.087089 | 0.098307 | 0.048945 | 0.027133 | 0.048945 |
| 4 | 1 | 539 | A0M8Q6 | R-HSA-2029481 | 0 | 0 | 1 | 0.003766 | 0.04759 | 0.040599 | 0.037591 | 0.030930 | 0.050208 | 0.044990 | 0.020216 | 0.037221 | 0.067721 | 0.038083 | 0.016489 | 0.038083 |
In [5]:
# Flag and drop "copy" edges, i.e. connections whose source and target share the same name
importance_df["copy"] = importance_df["source name"] == importance_df["target name"]
importance_df = importance_df[~importance_df["copy"]]
importance_df
Out[5]:
| | source | target | source name | target name | type | source layer | target layer | value_0 | value_1 | value_2 | ... | value_4 | value_5 | value_6 | value_7 | value_8 | value_9 | value_mean | values_std | value | copy |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 497 | A0M8Q6 | R-HSA-166663 | 0 | 0 | 1 | 0.003766 | 0.047590 | 0.040599 | ... | 0.030930 | 0.050208 | 0.044990 | 0.020216 | 0.037221 | 0.067721 | 0.038083 | 0.016489 | 0.038083 | False |
| 1 | 1 | 497 | A0M8Q6 | R-HSA-166663 | 1 | 0 | 1 | 0.020309 | 0.046800 | 0.014003 | ... | 0.056983 | 0.062000 | 0.052162 | 0.033551 | 0.087089 | 0.098307 | 0.048945 | 0.027133 | 0.048945 | False |
| 2 | 1 | 954 | A0M8Q6 | R-HSA-198933 | 0 | 0 | 1 | 0.003766 | 0.047590 | 0.040599 | ... | 0.030930 | 0.050208 | 0.044990 | 0.020216 | 0.037221 | 0.067721 | 0.038083 | 0.016489 | 0.038083 | False |
| 3 | 1 | 954 | A0M8Q6 | R-HSA-198933 | 1 | 0 | 1 | 0.020309 | 0.046800 | 0.014003 | ... | 0.056983 | 0.062000 | 0.052162 | 0.033551 | 0.087089 | 0.098307 | 0.048945 | 0.027133 | 0.048945 | False |
| 4 | 1 | 539 | A0M8Q6 | R-HSA-2029481 | 0 | 0 | 1 | 0.003766 | 0.047590 | 0.040599 | ... | 0.030930 | 0.050208 | 0.044990 | 0.020216 | 0.037221 | 0.067721 | 0.038083 | 0.016489 | 0.038083 | False |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 6901 | 1319 | 0 | R-HSA-9612973 | root | 1 | 4 | 5 | 0.210430 | 0.345282 | 0.247439 | ... | 0.027536 | 0.200686 | 0.034322 | 0.072509 | 0.065099 | 0.055107 | 0.130016 | 0.105911 | 0.130016 | False |
| 6902 | 1320 | 0 | R-HSA-9709957 | root | 0 | 4 | 5 | 0.326383 | 0.150743 | 0.016936 | ... | 0.020470 | 0.271241 | 0.300268 | 0.062605 | 0.252250 | 0.265667 | 0.177835 | 0.113163 | 0.177835 | False |
| 6903 | 1320 | 0 | R-HSA-9709957 | root | 1 | 4 | 5 | 0.143852 | 0.296020 | 0.360922 | ... | 0.321963 | 0.252727 | 0.338927 | 0.087586 | 0.328766 | 0.161152 | 0.260301 | 0.090516 | 0.260301 | False |
| 6904 | 1321 | 0 | R-HSA-9748784 | root | 0 | 4 | 5 | 0.048347 | 0.106638 | 0.033202 | ... | 0.035684 | 0.144263 | 0.139459 | 0.157461 | 0.186026 | 0.096128 | 0.115311 | 0.058665 | 0.115311 | False |
| 6905 | 1321 | 0 | R-HSA-9748784 | root | 1 | 4 | 5 | 0.152744 | 0.171184 | 0.042447 | ... | 0.046416 | 0.024173 | 0.066416 | 0.053887 | 0.158463 | 0.207614 | 0.098276 | 0.062971 | 0.098276 | False |
6580 rows × 21 columns
In [6]:
# Average over the remaining rows (classes in "type" and different targets) per source node and layer
importance_df_copy = importance_df.groupby(["source name", "source layer", "target layer"], as_index=False).mean(numeric_only=True)
mean_ranks = []
std_ranks = []
source_layer = []
sources = []
for layer in range(binn.n_layers):
    layer_df = importance_df_copy[importance_df_copy["source layer"] == layer].copy()
    # Rank the nodes of this layer separately for every iteration (rank 0 = most important)
    for i in range(n_iterations):
        layer_df.sort_values(f"value_{i}", ascending=False, inplace=True)
        layer_df[f"rank_{i}"] = range(len(layer_df.index))
    # Mean and standard deviation of each node's rank, normalized by the number of nodes in the layer
    rank_cols = [c for c in layer_df.columns if c.startswith("rank")]
    mean_ranks += (layer_df[rank_cols].mean(axis=1) / len(layer_df.index)).tolist()
    sources += layer_df["source name"].tolist()
    std_ranks += (layer_df[rank_cols].std(axis=1) / len(layer_df.index)).tolist()
    source_layer += layer_df["source layer"].tolist()
plot_df = pd.DataFrame({"mean": mean_ranks, "std": std_ranks, "source layer": source_layer, "source": sources})
In [7]:
plot_df.sort_values("mean").head(20)
Out[7]:
| | mean | std | source layer | source |
|---|---|---|---|---|
| 449 | 0.000651 | 0.002060 | 1 | R-HSA-975634 |
| 3 | 0.001114 | 0.002406 | 0 | P04114 |
| 2 | 0.009354 | 0.005435 | 0 | P02647 |
| 759 | 0.009653 | 0.017576 | 2 | R-HSA-975634 |
| 1 | 0.014477 | 0.027282 | 0 | P60709 |
| 8 | 0.016036 | 0.011634 | 0 | P00734 |
| 4 | 0.017817 | 0.008529 | 0 | P06727 |
| 12 | 0.019822 | 0.015122 | 0 | P02452 |
| 6 | 0.021826 | 0.010488 | 0 | Q9UBR2 |
| 0 | 0.026281 | 0.052071 | 0 | P08571 |
| 29 | 0.027394 | 0.021569 | 0 | P04908 |
| 451 | 0.028013 | 0.060282 | 1 | R-HSA-5696394 |
| 11 | 0.028285 | 0.025199 | 0 | P10451 |
| 42 | 0.036080 | 0.027853 | 0 | P00742 |
| 7 | 0.037416 | 0.016460 | 0 | P13611 |
| 10 | 0.037639 | 0.040546 | 0 | P68871 |
| 1021 | 0.039655 | 0.064152 | 3 | R-HSA-2187338 |
| 49 | 0.042316 | 0.028541 | 0 | P27797 |
| 50 | 0.043207 | 0.029866 | 0 | P01042 |
| 16 | 0.045434 | 0.021598 | 0 | P02656 |
If we now plot the normalized mean rank against the standard deviation of the rank (a normalized rank of 0 marks the most important node in a layer), we see that highly important features generally have a low standard deviation. We can therefore be most certain about the most important features.
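The per-layer ranking loop of In [6] can also be expressed with pandas' DataFrame.rank, which ranks every iteration column in one call. This is a hypothetical alternative (normalized_rank_stats is not part of BINN), shown on a three-feature toy table; ties may be broken differently than with repeated sort_values calls.

```python
import pandas as pd

def normalized_rank_stats(layer_df: pd.DataFrame, n_iterations: int) -> pd.DataFrame:
    """Hypothetical helper: mean/std of each node's normalized rank over iterations."""
    value_cols = [f"value_{i}" for i in range(n_iterations)]
    # Rank each iteration column separately; subtract 1 so rank 0 = most important
    ranks = layer_df[value_cols].rank(ascending=False, method="first") - 1
    ranks = ranks / len(layer_df.index)  # normalize by the number of nodes in the layer
    return pd.DataFrame({
        "source": layer_df["source name"],
        "mean": ranks.mean(axis=1),
        "std": ranks.std(axis=1),  # pandas default ddof=1, as in the notebook cell
    })

# Toy importances for three nodes over two iterations
toy = pd.DataFrame({
    "source name": ["A", "B", "C"],
    "value_0": [0.9, 0.5, 0.1],
    "value_1": [0.8, 0.1, 0.6],
})
stats = normalized_rank_stats(toy, n_iterations=2)
print(stats)
```

Node A is ranked first in both iterations, so its normalized mean rank and its standard deviation are both 0, while B and C swap places and share a mean of 0.5 with a non-zero spread. This is the pattern the plot summarizes: consistently top-ranked features sit in the low-mean, low-std corner.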
In [8]:
import seaborn as sns
sns.jointplot(data=plot_df, x="std", y="mean", hue="source layer", palette="coolwarm")
Out[8]:
<seaborn.axisgrid.JointGrid at 0x2aa140fa0>
In [9]:
from binn import ImportanceNetwork
IG = ImportanceNetwork(importance_df, norm_method="fan")
In [10]:
IG.plot_complete_sankey(
multiclass=False, savename="img/complete_sankey.png", node_cmap="Accent_r", edge_cmap="Accent_r"
)
In [11]:
query_node = "R-HSA-597592"
IG.plot_subgraph_sankey(
query_node, upstream=True, savename="img/subgraph_sankey.png", cmap="coolwarm"
)