In [16]:
Copied!
from binn import Network, BINN
import pandas as pd

# Load the quantification matrix plus the pathway and translation
# (feature -> pathway mapping) tables shipped with the examples.
input_data = pd.read_csv("../data/test_qm.csv")
translation = pd.read_csv("../data/translation.tsv", sep="\t")
pathways = pd.read_csv("../data/pathways.tsv", sep="\t")

# Build the hierarchical graph that defines the sparse connectivity
# of the biologically informed neural network.
network = Network(
    input_data=input_data,
    pathways=pathways,
    mapping=translation,
    source_column="child",
    target_column="parent",
)

# Instantiate the BINN model on top of that graph (CPU, 4 layers).
binn = BINN(
    network=network,
    n_layers=4,
    dropout=0.2,
    validate=False,
    residual=False,
    device="cpu",
    learning_rate=0.001,
)
from binn import Network, BINN
import pandas as pd

# Load the quantification matrix plus the pathway and translation
# (feature -> pathway mapping) tables shipped with the examples.
input_data = pd.read_csv("../data/test_qm.csv")
translation = pd.read_csv("../data/translation.tsv", sep="\t")
pathways = pd.read_csv("../data/pathways.tsv", sep="\t")

# Build the hierarchical graph that defines the sparse connectivity
# of the biologically informed neural network.
network = Network(
    input_data=input_data,
    pathways=pathways,
    mapping=translation,
    source_column="child",
    target_column="parent",
)

# Instantiate the BINN model on top of that graph (CPU, 4 layers).
binn = BINN(
    network=network,
    n_layers=4,
    dropout=0.2,
    validate=False,
    residual=False,
    device="cpu",
    learning_rate=0.001,
)
BINN is on the device: cpu
In [17]:
Copied!
from util_for_examples import fit_data_matrix_to_network_input, generate_data
import torch
from lightning.pytorch import Trainer

# Align the quant-matrix columns with the network's input features and
# derive the (X, y) training arrays from the experimental design.
design_matrix = pd.read_csv("../data/design_matrix.tsv", sep="\t")
protein_matrix = fit_data_matrix_to_network_input(input_data, features=network.inputs)
X, y = generate_data(protein_matrix, design_matrix=design_matrix)

# Wrap the arrays in a TensorDataset / DataLoader for mini-batching.
dataset = torch.utils.data.TensorDataset(
    torch.tensor(X, dtype=torch.float32, device=binn.device),
    torch.tensor(y, dtype=torch.int16, device=binn.device),
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)

# You can train using the Lightning Trainer
trainer = Trainer(max_epochs=10, log_every_n_steps=10)
#trainer.fit(binn, dataloader)
from util_for_examples import fit_data_matrix_to_network_input, generate_data
import torch
from lightning.pytorch import Trainer

# Align the quant-matrix columns with the network's input features and
# derive the (X, y) training arrays from the experimental design.
design_matrix = pd.read_csv("../data/design_matrix.tsv", sep="\t")
protein_matrix = fit_data_matrix_to_network_input(input_data, features=network.inputs)
X, y = generate_data(protein_matrix, design_matrix=design_matrix)

# Wrap the arrays in a TensorDataset / DataLoader for mini-batching.
dataset = torch.utils.data.TensorDataset(
    torch.tensor(X, dtype=torch.float32, device=binn.device),
    torch.tensor(y, dtype=torch.int16, device=binn.device),
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)

# You can train using the Lightning Trainer
trainer = Trainer(max_epochs=10, log_every_n_steps=10)
#trainer.fit(binn, dataloader)
GPU available: True (mps), used: True TPU available: False, using: 0 TPU cores IPU available: False, using: 0 IPUs HPU available: False, using: 0 HPUs
Because training with the Lightning `Trainer` is slow here (new workers are created for each epoch), we can implement our own training loop in standard PyTorch fashion.
In [18]:
Copied!
import torch.nn.functional as F

# You can also train with a standard PyTorch train loop.
# configure_optimizers() returns ([optimizers], [schedulers]); take the first optimizer.
optimizer = binn.configure_optimizers()[0][0]

num_epochs = 30
for epoch in range(num_epochs):
    binn.train()
    total_loss = 0.0
    total_accuracy = 0.0
    for inputs, targets in dataloader:
        inputs = inputs.to(binn.device)
        # Cast labels to int64 on the model's device. The previous
        # `.type(torch.LongTensor)` silently moved targets back to the CPU,
        # which would break training on a CUDA/MPS device.
        targets = targets.to(binn.device, dtype=torch.long)
        optimizer.zero_grad()
        outputs = binn(inputs)  # already on binn.device; no extra .to() needed
        loss = F.cross_entropy(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # `dim=` is the torch idiom (was `axis=`); `.item()` accumulates a
        # plain Python float instead of a tensor.
        batch_accuracy = (torch.argmax(outputs, dim=1) == targets).float().mean()
        total_accuracy += batch_accuracy.item()
    avg_loss = total_loss / len(dataloader)
    avg_accuracy = total_accuracy / len(dataloader)
    print(f'Epoch {epoch}, Average Accuracy {avg_accuracy}, Average Loss: {avg_loss}')
import torch.nn.functional as F

# You can also train with a standard PyTorch train loop.
# configure_optimizers() returns ([optimizers], [schedulers]); take the first optimizer.
optimizer = binn.configure_optimizers()[0][0]

num_epochs = 30
for epoch in range(num_epochs):
    binn.train()
    total_loss = 0.0
    total_accuracy = 0.0
    for inputs, targets in dataloader:
        inputs = inputs.to(binn.device)
        # Cast labels to int64 on the model's device. The previous
        # `.type(torch.LongTensor)` silently moved targets back to the CPU,
        # which would break training on a CUDA/MPS device.
        targets = targets.to(binn.device, dtype=torch.long)
        optimizer.zero_grad()
        outputs = binn(inputs)  # already on binn.device; no extra .to() needed
        loss = F.cross_entropy(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # `dim=` is the torch idiom (was `axis=`); `.item()` accumulates a
        # plain Python float instead of a tensor.
        batch_accuracy = (torch.argmax(outputs, dim=1) == targets).float().mean()
        total_accuracy += batch_accuracy.item()
    avg_loss = total_loss / len(dataloader)
    avg_accuracy = total_accuracy / len(dataloader)
    print(f'Epoch {epoch}, Average Accuracy {avg_accuracy}, Average Loss: {avg_loss}')
Epoch 0, Average Accuracy 0.5519999861717224, Average Loss: 0.7984879660606384 Epoch 1, Average Accuracy 0.7119999527931213, Average Loss: 0.5876554298400879 Epoch 2, Average Accuracy 0.7699999809265137, Average Loss: 0.5000520890951157 Epoch 3, Average Accuracy 0.8059999942779541, Average Loss: 0.4410453510284424 Epoch 4, Average Accuracy 0.7699999809265137, Average Loss: 0.48531828820705414 Epoch 5, Average Accuracy 0.8100000023841858, Average Loss: 0.38955733954906463 Epoch 6, Average Accuracy 0.7809999585151672, Average Loss: 0.4015652185678482 Epoch 7, Average Accuracy 0.7890000343322754, Average Loss: 0.4215572929382324 Epoch 8, Average Accuracy 0.8889999985694885, Average Loss: 0.31606161445379255 Epoch 9, Average Accuracy 0.8769999742507935, Average Loss: 0.3141865438222885 Epoch 10, Average Accuracy 0.8849999904632568, Average Loss: 0.31385430574417117 Epoch 11, Average Accuracy 0.871999979019165, Average Loss: 0.2942655232548714 Epoch 12, Average Accuracy 0.8399999737739563, Average Loss: 0.2976784099638462 Epoch 13, Average Accuracy 0.875, Average Loss: 0.25456758081912995 Epoch 14, Average Accuracy 0.9519999623298645, Average Loss: 0.1938611924648285 Epoch 15, Average Accuracy 0.9049999713897705, Average Loss: 0.2302371022105217 Epoch 16, Average Accuracy 0.9519999623298645, Average Loss: 0.1974144048988819 Epoch 17, Average Accuracy 0.8889999985694885, Average Loss: 0.22772141829133033 Epoch 18, Average Accuracy 0.9300000071525574, Average Loss: 0.21292504481971264 Epoch 19, Average Accuracy 0.925000011920929, Average Loss: 0.19984641671180725 Epoch 20, Average Accuracy 0.8640000224113464, Average Loss: 0.30628247559070587 Epoch 21, Average Accuracy 0.949999988079071, Average Loss: 0.18478416368365289 Epoch 22, Average Accuracy 0.8919999599456787, Average Loss: 0.2577762907743454 Epoch 23, Average Accuracy 0.9300000071525574, Average Loss: 0.19484564393758774 Epoch 24, Average Accuracy 0.9549999833106995, Average Loss: 0.17850346982479096 Epoch 25, 
Average Accuracy 0.9369999766349792, Average Loss: 0.18940596178174018 Epoch 26, Average Accuracy 0.909000039100647, Average Loss: 0.21654408738017084 Epoch 27, Average Accuracy 0.9449999928474426, Average Loss: 0.16795408174395562 Epoch 28, Average Accuracy 0.9369999766349792, Average Loss: 0.1838543491065502 Epoch 29, Average Accuracy 0.9350000023841858, Average Loss: 0.1619287095963955
In [19]:
Copied!
# Wrap the trained model in a BINNExplainer to compute feature/pathway importances.
from binn import BINNExplainer
explainer = BINNExplainer(binn)
# Wrap the trained model in a BINNExplainer to compute feature/pathway importances.
from binn import BINNExplainer
explainer = BINNExplainer(binn)
In [20]:
Copied!
# Use the first five samples as the background distribution and explain
# the model's predictions for the next five.
background_data = torch.Tensor(X[0:5])
test_data = torch.Tensor(X[5:10])
importance_df = explainer.explain(test_data, background_data)
importance_df.head()
# Use the first five samples as the background distribution and explain
# the model's predictions for the next five.
background_data = torch.Tensor(X[0:5])
test_data = torch.Tensor(X[5:10])
importance_df = explainer.explain(test_data, background_data)
importance_df.head()
Out[20]:
source | target | source name | target name | value | type | source layer | target layer | |
---|---|---|---|---|---|---|---|---|
0 | 1 | 497 | A0M8Q6 | R-HSA-166663 | 0.0 | 0 | 0 | 1 |
1 | 1 | 497 | A0M8Q6 | R-HSA-166663 | 0.0 | 1 | 0 | 1 |
2 | 1 | 954 | A0M8Q6 | R-HSA-198933 | 0.0 | 0 | 0 | 1 |
3 | 1 | 954 | A0M8Q6 | R-HSA-198933 | 0.0 | 1 | 0 | 1 |
4 | 1 | 539 | A0M8Q6 | R-HSA-2029481 | 0.0 | 0 | 0 | 1 |
In [21]:
Copied!
# Build an importance network from the explanation table.
# norm_method="fan": presumably normalizes values by node fan-in/out — confirm in BINN docs.
from binn import ImportanceNetwork
IG = ImportanceNetwork(importance_df, norm_method="fan")
# Build an importance network from the explanation table.
# norm_method="fan": presumably normalizes values by node fan-in/out — confirm in BINN docs.
from binn import ImportanceNetwork
IG = ImportanceNetwork(importance_df, norm_method="fan")
In [22]:
Copied!
# Render a Sankey diagram of the full importance network and save it to disk.
IG.plot_complete_sankey(
    multiclass=False,
    savename="img/complete_sankey.png",
    node_cmap="Accent_r",
    edge_cmap="Accent_r",
)
# Render a Sankey diagram of the full importance network and save it to disk.
IG.plot_complete_sankey(
    multiclass=False,
    savename="img/complete_sankey.png",
    node_cmap="Accent_r",
    edge_cmap="Accent_r",
)
In [23]:
Copied!
# Render the Sankey diagram for everything upstream of one pathway node.
query_node = "R-HSA-597592"
IG.plot_subgraph_sankey(
    query_node,
    upstream=True,
    savename="img/subgraph_sankey.png",
    cmap="coolwarm",
)
# Render the Sankey diagram for everything upstream of one pathway node.
query_node = "R-HSA-597592"
IG.plot_subgraph_sankey(
    query_node,
    upstream=True,
    savename="img/subgraph_sankey.png",
    cmap="coolwarm",
)