Importance Network & Visualization

ImportanceNetwork

A class for building and analyzing a directed graph representing the importance network of a system.

Parameters:

Name Type Description Default
importance_df DataFrame

A dataframe with columns for the source node, target node, value flow between nodes, and layer for each node. This dataframe should represent the complete importance network of the system.

required
val_col str

The name of the column in the DataFrame that represents the value flow between nodes. Defaults to "value".

'value'
norm_method str

Method used to normalize the value column. Options are 'subgraph' and 'fan'. If 'subgraph', each value is divided by log2 of the number of nodes in the upstream and downstream subgraphs of its source node. If 'fan', each value is divided by log2(fan-in + fan-out + 1) of its source node.

'subgraph'

Attributes:

Name Type Description
importance_df DataFrame

The dataframe used for downstream/upstream subgraph construction and plotting.

val_col str

The name of the column in the DataFrame that represents the value flow between nodes.

importance_graph DiGraph

A directed graph object representing the importance network of the system.

importance_graph_reverse DiGraph

A directed graph object representing the importance network of the system in reverse.

norm_method str

The normalization method.
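
The column names that the constructor relies on can be read off `create_graph` in the source listing: `source`, `target`, `source name`, `source layer`, `target layer`, and the value column. The following is a minimal sketch of rows in that shape, using plain dictionaries so that it runs without pandas. The node ids, names, and values are invented for illustration, and `missing_columns` is a hypothetical helper, not part of the library.

```python
# Hypothetical rows in the layout ImportanceNetwork reads; ids, names and values are invented.
REQUIRED_COLUMNS = {
    "source", "target", "source name", "source layer", "target layer", "value",
}

rows = [
    {"source": 1, "target": 3, "source name": "protein_A",
     "source layer": 0, "target layer": 1, "value": 0.8},
    {"source": 2, "target": 3, "source name": "protein_B",
     "source layer": 0, "target layer": 1, "value": 0.2},
    # Target 0 is the root node id used by the class (self.root_node = 0).
    {"source": 3, "target": 0, "source name": "pathway_X",
     "source layer": 1, "target layer": 2, "value": 1.0},
]


def missing_columns(rows, required=REQUIRED_COLUMNS):
    """Return the required column names that are absent from at least one row."""
    missing = set()
    for row in rows:
        missing |= required - set(row)
    return missing


print(missing_columns(rows))  # empty set when every row is complete
```

With pandas installed, `pd.DataFrame(rows)` would produce a frame in this layout.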

Source code in binn/importance_network.py
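from typing import Union

import networkx as nx
import numpy as np
import pandas as pd

# subgraph_sankey and complete_sankey are plotting helpers defined elsewhere in the binn package.


class ImportanceNetwork: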
    """
    A class for building and analyzing a directed graph representing the importance network of a system.

    Parameters:
        importance_df (pandas.DataFrame): A dataframe with columns for the source node, target node, value flow between nodes,
            and layer for each node. This dataframe should represent the complete importance network of the system.
        val_col (str, optional): The name of the column in the DataFrame that represents the value flow between
            nodes. Defaults to "value".
        norm_method (str, optional): Method used to normalize the value column. Options are 'subgraph' and 'fan'.
            If 'subgraph', divides each value by log2 of the number of nodes in the upstream and downstream
            subgraphs of its source node. If 'fan', divides by log2(fan-in + fan-out + 1) of its source node.
            Defaults to "subgraph".

    Attributes:
        importance_df (pandas.DataFrame): The dataframe used for downstream/upstream subgraph construction and plotting.
        val_col (str): The name of the column in the DataFrame that represents the value flow between nodes.
        importance_graph (networkx.DiGraph): A directed graph object representing the importance network of the system.
        importance_graph_reverse (networkx.DiGraph): A directed graph object representing the importance network of the system in reverse.
        norm_method (str): The normalization method.

    """

    def __init__(
        self,
        importance_df: pd.DataFrame,
        norm_method: str = "subgraph",
        val_col: str = "value",
    ):
        self.root_node = 0
        self.importance_df = importance_df
        self.val_col = val_col
        self.importance_graph = self.create_graph()
        self.importance_graph_reverse = self.importance_graph.reverse()
        self.norm_method = norm_method
        if norm_method:
            self.importance_df = self.add_normalization(method=norm_method)

    def plot_subgraph_sankey(
        self,
        query_node: str,
        upstream: bool = False,
        savename: str = "sankey.png",
        val_col: str = "value",
        cmap: str = "coolwarm",
        width: int = 1200,
        scale: float = 2.5,
        height: int = 500,
    ):
        """
        Generate a Sankey diagram using the provided query node.

        Args:
            query_node (str): The node to use as the starting point for the Sankey diagram.
            upstream (bool, optional): If True, the Sankey diagram will show the upstream flow of the
                query_node. If False (default), the Sankey diagram will show the downstream flow of the
                query_node.
            savename (str, optional): The file name to save the Sankey diagram as. Defaults to "sankey.png".
            val_col (str, optional): The column in the DataFrame that represents the value flow between
                nodes. Defaults to "value".
            cmap (str, optional): The name of the color map to use for the Sankey diagram. Defaults
                to "coolwarm".

        Returns:
            plotly.graph_objs._figure.Figure: The plotly Figure object representing the Sankey diagram.

        """
        if upstream:
            final_node_id = self.get_node(query_node)
            subgraph = self.get_upstream_subgraph(final_node_id, depth_limit=None)
            source_or_target = "target"
        else:
            if query_node == "root":
                raise ValueError("You cannot look downstream from root")
            final_node_id = self.get_node("root")
            query_node_id = self.get_node(query_node)
            subgraph = self.get_downstream_subgraph(self.importance_graph, query_node_id, depth_limit=None)
            source_or_target = "source"

        nodes_in_subgraph = [n for n in subgraph.nodes]

        df = self.importance_df[
            self.importance_df[source_or_target].isin(nodes_in_subgraph)
        ].copy()

        if df.empty:
            raise ValueError("There are no nodes in the specified subgraph")

        fig = subgraph_sankey(
            df, final_node=final_node_id, val_col=val_col, cmap_name=cmap
        )

        fig.write_image(f"{savename}", width=width, scale=scale, height=height)
        return fig

    def plot_complete_sankey(
        self,
        multiclass: bool = False,
        show_top_n: int = 10,
        node_cmap: str = "Reds",
        edge_cmap: Union[str, list] = "Reds",
        savename="sankey.png",
        width: int = 1900,
        scale: float = 2,
        height: int = 800,
    ):
        """
        Plot a complete Sankey diagram for the importance network.

        Parameters:
            multiclass : bool, optional
                If True, plot multiclass Sankey diagram. Defaults to False.
            show_top_n : int, optional
                Show only the top N nodes in the Sankey diagram. Defaults to 10.
            node_cmap : str, optional
                The color map for the nodes. Defaults to "Reds".
            edge_cmap : str or list, optional
                The color map for the edges. Defaults to "Reds".
            savename : str, optional
                The filename to save the plot. Defaults to "sankey.png".

        Returns:
            plotly.graph_objs._figure.Figure: The plotly Figure object representing the Sankey diagram.
        """

        fig = complete_sankey(
            self.importance_df,
            multiclass=multiclass,
            val_col=self.val_col,
            show_top_n=show_top_n,
            edge_cmap=edge_cmap,
            node_cmap=node_cmap,
        )

        fig.write_image(f"{savename}", width=width, scale=scale, height=height)
        return fig

    def create_graph(self):
        """
        Create a directed graph (DiGraph) from the source and target nodes in the input dataframe.

        Returns:
            importance_graph: a directed graph (DiGraph) object
        """
        importance_graph = nx.DiGraph()
        for k in self.importance_df.iterrows():
            source_name = k[1]["source name"]
            source = k[1]["source"]
            value = k[1][self.val_col]
            source_layer = k[1]["source layer"] + 1
            importance_graph.add_node(
                source, weight=value, layer=source_layer, name=source_name
            )
        for k in self.importance_df.iterrows():
            source = k[1]["source"]
            target = k[1]["target"]
            importance_graph.add_edge(source, target)
        root_layer = max(self.importance_df["target layer"]) + 1
        importance_graph.add_node(
            self.root_node, weight=0, layer=root_layer, name="root"
        )
        return importance_graph

    def get_downstream_subgraph(self, graph, query_node: str, depth_limit=None):
        """
        Get a subgraph that contains all nodes downstream of the query node up to the given depth limit (if provided).

        Args:
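            graph: the directed graph (DiGraph) to traverse; the importance graph, or its reverse for upstream queries
            query_node: a string representing the name of the node from which the downstream subgraph is constructed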
            depth_limit: an integer representing the maximum depth to which the subgraph is constructed (optional)

        Returns:
            subgraph: a directed graph (DiGraph) object containing all nodes downstream of the query node, up to the given depth limit
        """
        subgraph = nx.DiGraph()
        nodes = [
            n
            for n in nx.traversal.bfs_successors(
                graph, query_node, depth_limit=depth_limit
            )
            if n != query_node
        ]
        for source, targets in nodes:
            subgraph.add_node(source, **graph.nodes()[source])
            for t in targets:
                subgraph.add_node(t, **graph.nodes()[t])
        for node1 in subgraph.nodes():
            for node2 in subgraph.nodes():
                if graph.has_edge(node1, node2):
                    subgraph.add_edge(node1, node2)

        subgraph.add_node(query_node, **graph.nodes()[query_node])
        return subgraph

    def get_upstream_subgraph(self, query_node: str, depth_limit=None):
        """
        Get a subgraph that contains all nodes upstream of the query node up to the given depth limit (if provided).

        Args:
            query_node: a string representing the name of the node from which the upstream subgraph is constructed
            depth_limit: an integer representing the maximum depth to which the subgraph is constructed (optional)

        Returns:
            subgraph: a directed graph (DiGraph) object containing all nodes upstream of the query node, up to the given depth limit
        """
        subgraph = self.get_downstream_subgraph(self.importance_graph_reverse, query_node, depth_limit=depth_limit)
        return subgraph

    def get_complete_subgraph(self, query_node: str, depth_limit=None):
        """
        Get a subgraph that contains all nodes both upstream and downstream of the query node up to the given depth limit (if provided).

        Args:
            query_node: a string representing the name of the node from which the complete subgraph is constructed
            depth_limit: an integer representing the maximum depth to which the subgraph is constructed (optional)

        Returns:
            subgraph: a directed graph (DiGraph) object containing all nodes both upstream and downstream of the query node, up to the given depth limit
        """
        subgraph = self.get_downstream_subgraph(self.importance_graph, query_node, depth_limit=depth_limit)
        nodes = [
            n
            for n in nx.traversal.bfs_successors(
                self.importance_graph_reverse, query_node, depth_limit=depth_limit
            )
            if n != query_node
        ]
        for source, targets in nodes:
            subgraph.add_node(source, **self.importance_graph_reverse.nodes()[source])
            for t in targets:
                subgraph.add_node(t, **self.importance_graph_reverse.nodes()[t])
        for node1 in subgraph.nodes():
            for node2 in subgraph.nodes():
                if self.importance_graph_reverse.has_edge(node1, node2):
                    subgraph.add_edge(node1, node2)

        subgraph.add_node(
            query_node, **self.importance_graph_reverse.nodes()[query_node]
        )
        return subgraph

    def get_nr_nodes_in_upstream_subgraph(self, query_node: str):
        """
        Get the number of nodes in the upstream subgraph of the query node.

        Args:
            query_node: a string representing the name of the node from which the upstream subgraph is constructed

        Returns:
            the number of nodes in the upstream subgraph of the query node
        """

        subgraph = self.get_upstream_subgraph(query_node, depth_limit=None)
        return subgraph.number_of_nodes()

    def get_nr_nodes_in_downstream_subgraph(self, query_node: str):
        """
        Get the number of nodes in the downstream subgraph of the query node.

        Args:
            query_node: a string representing the name of the node from which the downstream subgraph is constructed

        Returns:
            the number of nodes in the downstream subgraph of the query node
        """
        subgraph = self.get_downstream_subgraph(self.importance_graph, query_node, depth_limit=None)
        return subgraph.number_of_nodes()

    def get_fan_in(self, query_node: str):
        """
        Get the number of incoming edges (fan-in) for the query node.

        Args:
            query_node: a string representing the name of the node

        Returns:
            the number of incoming edges (fan-in) for the query node
        """
        return len([n for n in self.importance_graph.in_edges(query_node)])

    def get_fan_out(self, query_node: str):
        """
        Get the number of outgoing edges (fan-out) for the query node.

        Args:
            query_node: a string representing the name of the node

        Returns:
            the number of outgoing edges (fan-out) for the query node
        """
        return len([n for n in self.importance_graph.out_edges(query_node)])

    def add_normalization(self, method: str = "subgraph"):
        """
        Adds normalization to the importance values based on the specified method.

        Args:
            method (str): The normalization method to use. Options are "fan" and "subgraph".
                "fan" normalizes based on fan-in and fan-out values.
                "subgraph" normalizes based on the number of nodes in the upstream and
                downstream subgraphs.

        Returns:
            pd.DataFrame: The importance dataframe with the normalized values added.
        """
        if method == "fan":
            fan_in = np.array([self.get_fan_in(x) for x in self.importance_df["source"]])
            fan_out = np.array([self.get_fan_out(x) for x in self.importance_df["source"]])
            nr_tot = fan_in + fan_out + 1
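        elif method == "subgraph":
            upstream_nodes = np.array(
                [
                    self.get_nr_nodes_in_upstream_subgraph(x)
                    for x in self.importance_df["source"]
                ]
            )
            downstream_nodes = np.array(
                [
                    self.get_nr_nodes_in_downstream_subgraph(x)
                    for x in self.importance_df["source"]
                ]
            )
            nr_tot = upstream_nodes + downstream_nodes
        else:
            # Fail early instead of hitting an undefined nr_tot below.
            raise ValueError(f"Unknown normalization method: {method!r}")

        # Normalize the configured value column rather than a hard-coded "value".
        self.importance_df[self.val_col] = self.importance_df[self.val_col] / np.log2(nr_tot)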
        return self.importance_df

    def get_node(self, name):
        for node, d in self.importance_graph.nodes(data=True):
            if d["name"] == name:
                return node
        raise ValueError(f"Could not find node {name}")

add_normalization(method='subgraph')

Adds normalization to the importance values based on the specified method.

Parameters:

Name Type Description Default
method str

The normalization method to use. Options are "fan" and "subgraph". "fan" normalizes based on fan-in and fan-out values. "subgraph" normalizes based on the number of nodes in the upstream and downstream subgraphs.

'subgraph'

Returns:

Type Description

pd.DataFrame: The importance dataframe with the normalized values added.
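
The arithmetic behind the two methods, as it appears in the source listing, can be sketched on single numbers. The function names and inputs below are illustrative only; the class itself applies the same division to whole columns with NumPy.

```python
import math


def normalize_subgraph(value, n_upstream, n_downstream):
    """Divide by log2 of the combined upstream and downstream subgraph sizes."""
    return value / math.log2(n_upstream + n_downstream)


def normalize_fan(value, fan_in, fan_out):
    """Divide by log2(fan_in + fan_out + 1)."""
    return value / math.log2(fan_in + fan_out + 1)


# A node with 3 upstream and 5 downstream nodes (each count includes the node itself):
print(normalize_subgraph(6.0, 3, 5))  # 6 / log2(8) = 2.0
# A node with 1 incoming and 2 outgoing edges:
print(normalize_fan(6.0, 1, 2))  # 6 / log2(4) = 3.0
```

Because both subgraph counts include the query node, their sum is at least 2, so the logarithm in the 'subgraph' method is never zero.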

Source code in binn/importance_network.py
def add_normalization(self, method: str = "subgraph"):
    """
    Adds normalization to the importance values based on the specified method.

    Args:
        method (str): The normalization method to use. Options are "fan" and "subgraph".
            "fan" normalizes based on fan-in and fan-out values.
            "subgraph" normalizes based on the number of nodes in the upstream and
            downstream subgraphs.

    Returns:
        pd.DataFrame: The importance dataframe with the normalized values added.
    """
    if method == "fan":
        fan_in = np.array([self.get_fan_in(x) for x in self.importance_df["source"]])
        fan_out = np.array([self.get_fan_out(x) for x in self.importance_df["source"]])
        nr_tot = fan_in + fan_out + 1
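    elif method == "subgraph":
        upstream_nodes = np.array(
            [
                self.get_nr_nodes_in_upstream_subgraph(x)
                for x in self.importance_df["source"]
            ]
        )
        downstream_nodes = np.array(
            [
                self.get_nr_nodes_in_downstream_subgraph(x)
                for x in self.importance_df["source"]
            ]
        )
        nr_tot = upstream_nodes + downstream_nodes
    else:
        # Fail early instead of hitting an undefined nr_tot below.
        raise ValueError(f"Unknown normalization method: {method!r}")

    # Normalize the configured value column rather than a hard-coded "value".
    self.importance_df[self.val_col] = self.importance_df[self.val_col] / np.log2(nr_tot)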
    return self.importance_df

create_graph()

Create a directed graph (DiGraph) from the source and target nodes in the input dataframe.

Returns:

Name Type Description
importance_graph

a directed graph (DiGraph) object
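
The construction steps in the source listing (source nodes with attributes, then edges, then a root node one layer past the deepest target layer) can be sketched with plain dictionaries. This is an illustration of the logic only: `build_graph` and the sample rows are hypothetical, and the real method returns a networkx `DiGraph`.

```python
def build_graph(rows, val_col="value", root_node=0):
    """Mirror create_graph with a dict of node attributes and a set of edges."""
    nodes, edges = {}, set()
    # First pass: one attributed node per source; a later row for the same
    # source overwrites the earlier attributes.
    for row in rows:
        nodes[row["source"]] = {
            "weight": row[val_col],
            "layer": row["source layer"] + 1,
            "name": row["source name"],
        }
    # Second pass: edges; targets that never appear as a source get no attributes yet.
    for row in rows:
        edges.add((row["source"], row["target"]))
        nodes.setdefault(row["target"], {})
    # The root sits one layer past the deepest target layer.
    root_layer = max(row["target layer"] for row in rows) + 1
    nodes.setdefault(root_node, {}).update(weight=0, layer=root_layer, name="root")
    return nodes, edges


rows = [
    {"source": 1, "target": 3, "source name": "protein_A",
     "source layer": 0, "target layer": 1, "value": 0.8},
    {"source": 2, "target": 3, "source name": "protein_B",
     "source layer": 0, "target layer": 1, "value": 0.2},
    {"source": 3, "target": 0, "source name": "pathway_X",
     "source layer": 1, "target layer": 2, "value": 1.0},
]
nodes, edges = build_graph(rows)
print(nodes[0])
```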

Source code in binn/importance_network.py
def create_graph(self):
    """
    Create a directed graph (DiGraph) from the source and target nodes in the input dataframe.

    Returns:
        importance_graph: a directed graph (DiGraph) object
    """
    importance_graph = nx.DiGraph()
    for k in self.importance_df.iterrows():
        source_name = k[1]["source name"]
        source = k[1]["source"]
        value = k[1][self.val_col]
        source_layer = k[1]["source layer"] + 1
        importance_graph.add_node(
            source, weight=value, layer=source_layer, name=source_name
        )
    for k in self.importance_df.iterrows():
        source = k[1]["source"]
        target = k[1]["target"]
        importance_graph.add_edge(source, target)
    root_layer = max(self.importance_df["target layer"]) + 1
    importance_graph.add_node(
        self.root_node, weight=0, layer=root_layer, name="root"
    )
    return importance_graph

get_complete_subgraph(query_node, depth_limit=None)

Get a subgraph that contains all nodes both upstream and downstream of the query node up to the given depth limit (if provided).

Parameters:

Name Type Description Default
query_node str

a string representing the name of the node from which the complete subgraph is constructed

required
depth_limit

an integer representing the maximum depth to which the subgraph is constructed (optional)

None

Returns:

Name Type Description
subgraph

a directed graph (DiGraph) object containing all nodes both upstream and downstream of the query node, up to the given depth limit
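
Conceptually, the node set of the complete subgraph is the union of what can be reached by following edges forward and what can be reached by following them backward. The sketch below shows that idea on a small invented edge list, without a depth limit and without networkx; `reachable` is a hypothetical helper and the code does not reproduce how the method assembles edges.

```python
from collections import defaultdict


def reachable(adjacency, start):
    """All nodes reachable from start by following adjacency, including start."""
    seen, stack = {start}, [start]
    while stack:
        for nxt in adjacency[stack.pop()]:
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return seen


edges = [(1, 3), (2, 3), (3, 0), (4, 0)]
forward, backward = defaultdict(list), defaultdict(list)
for source, target in edges:
    forward[source].append(target)   # original direction
    backward[target].append(source)  # reversed direction, as in DiGraph.reverse()

query = 3
downstream = reachable(forward, query)
upstream = reachable(backward, query)
complete = downstream | upstream
print(sorted(complete))  # node 4 is neither upstream nor downstream of 3
```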

Source code in binn/importance_network.py
def get_complete_subgraph(self, query_node: str, depth_limit=None):
    """
    Get a subgraph that contains all nodes both upstream and downstream of the query node up to the given depth limit (if provided).

    Args:
        query_node: a string representing the name of the node from which the complete subgraph is constructed
        depth_limit: an integer representing the maximum depth to which the subgraph is constructed (optional)

    Returns:
        subgraph: a directed graph (DiGraph) object containing all nodes both upstream and downstream of the query node, up to the given depth limit
    """
    subgraph = self.get_downstream_subgraph(self.importance_graph, query_node, depth_limit=depth_limit)
    nodes = [
        n
        for n in nx.traversal.bfs_successors(
            self.importance_graph_reverse, query_node, depth_limit=depth_limit
        )
        if n != query_node
    ]
    for source, targets in nodes:
        subgraph.add_node(source, **self.importance_graph_reverse.nodes()[source])
        for t in targets:
            subgraph.add_node(t, **self.importance_graph_reverse.nodes()[t])
    for node1 in subgraph.nodes():
        for node2 in subgraph.nodes():
            if self.importance_graph_reverse.has_edge(node1, node2):
                subgraph.add_edge(node1, node2)

    subgraph.add_node(
        query_node, **self.importance_graph_reverse.nodes()[query_node]
    )
    return subgraph

get_downstream_subgraph(graph, query_node, depth_limit=None)

Get a subgraph that contains all nodes downstream of the query node up to the given depth limit (if provided).

Parameters:

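Name Type Description Default
graph DiGraph

the directed graph to traverse: the importance graph, or its reverse for upstream queries

required
query_node str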

a string representing the name of the node from which the downstream subgraph is constructed

required
depth_limit

an integer representing the maximum depth to which the subgraph is constructed (optional)

None

Returns:

Name Type Description
subgraph

a directed graph (DiGraph) object containing all nodes downstream of the query node, up to the given depth limit
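
The method relies on networkx's breadth-first traversal with an optional `depth_limit`. As a rough illustration of what that limit means, here is a plain-Python breadth-first search over a successor mapping. `downstream_nodes` and the sample graph are assumptions made for this sketch, and it returns only the node set rather than a graph object.

```python
from collections import deque


def downstream_nodes(successors, query_node, depth_limit=None):
    """Nodes reachable from query_node (inclusive), at most depth_limit edges away."""
    seen = {query_node}
    queue = deque([(query_node, 0)])
    while queue:
        node, depth = queue.popleft()
        # Do not expand nodes that already sit at the depth limit.
        if depth_limit is not None and depth >= depth_limit:
            continue
        for nxt in successors.get(node, ()):
            if nxt not in seen:
                seen.add(nxt)
                queue.append((nxt, depth + 1))
    return seen


successors = {1: [3], 2: [3], 3: [0]}
print(downstream_nodes(successors, 1))                 # whole downstream chain
print(downstream_nodes(successors, 1, depth_limit=1))  # direct successors only
```

As in the method itself, the query node is part of the result, which is why the node counts used for normalization include it.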

Source code in binn/importance_network.py
def get_downstream_subgraph(self, graph, query_node: str, depth_limit=None):
    """
    Get a subgraph that contains all nodes downstream of the query node up to the given depth limit (if provided).

    Args:
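        graph: the directed graph (DiGraph) to traverse; the importance graph, or its reverse for upstream queries
        query_node: a string representing the name of the node from which the downstream subgraph is constructed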
        depth_limit: an integer representing the maximum depth to which the subgraph is constructed (optional)

    Returns:
        subgraph: a directed graph (DiGraph) object containing all nodes downstream of the query node, up to the given depth limit
    """
    subgraph = nx.DiGraph()
    nodes = [
        n
        for n in nx.traversal.bfs_successors(
            graph, query_node, depth_limit=depth_limit
        )
        if n != query_node
    ]
    for source, targets in nodes:
        subgraph.add_node(source, **graph.nodes()[source])
        for t in targets:
            subgraph.add_node(t, **graph.nodes()[t])
    for node1 in subgraph.nodes():
        for node2 in subgraph.nodes():
            if graph.has_edge(node1, node2):
                subgraph.add_edge(node1, node2)

    subgraph.add_node(query_node, **graph.nodes()[query_node])
    return subgraph

get_fan_in(query_node)

Get the number of incoming edges (fan-in) for the query node.

Parameters:

Name Type Description Default
query_node str

a string representing the name of the node

required

Returns:

Type Description

the number of incoming edges (fan-in) for the query node

Source code in binn/importance_network.py
def get_fan_in(self, query_node: str):
    """
    Get the number of incoming edges (fan-in) for the query node.

    Args:
        query_node: a string representing the name of the node

    Returns:
        the number of incoming edges (fan-in) for the query node
    """
    return len([n for n in self.importance_graph.in_edges(query_node)])

get_fan_out(query_node)

Get the number of outgoing edges (fan-out) for the query node.

Parameters:

Name Type Description Default
query_node str

a string representing the name of the node

required

Returns:

Type Description

the number of outgoing edges (fan-out) for the query node
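
Fan-in and fan-out are simply edge counts per node. A small sketch on an invented edge list, using `collections.Counter` instead of a graph object:

```python
from collections import Counter

# Invented (source, target) pairs for illustration.
edges = [(1, 3), (2, 3), (3, 0), (4, 0)]

fan_out = Counter(source for source, _ in edges)  # outgoing edges per node
fan_in = Counter(target for _, target in edges)   # incoming edges per node

print(fan_in[3], fan_out[3])  # node 3 has 2 incoming edges and 1 outgoing edge
```

The 'fan' normalization adds 1 to the sum of the two counts before taking the logarithm.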

Source code in binn/importance_network.py
def get_fan_out(self, query_node: str):
    """
    Get the number of outgoing edges (fan-out) for the query node.

    Args:
        query_node: a string representing the name of the node

    Returns:
        the number of outgoing edges (fan-out) for the query node
    """
    return len([n for n in self.importance_graph.out_edges(query_node)])

get_nr_nodes_in_downstream_subgraph(query_node)

Get the number of nodes in the downstream subgraph of the query node.

Parameters:

Name Type Description Default
query_node str

a string representing the name of the node from which the downstream subgraph is constructed

required

Returns:

Type Description

the number of nodes in the downstream subgraph of the query node

Source code in binn/importance_network.py
def get_nr_nodes_in_downstream_subgraph(self, query_node: str):
    """
    Get the number of nodes in the downstream subgraph of the query node.

    Args:
        query_node: a string representing the name of the node from which the downstream subgraph is constructed

    Returns:
        the number of nodes in the downstream subgraph of the query node
    """
    subgraph = self.get_downstream_subgraph(self.importance_graph, query_node, depth_limit=None)
    return subgraph.number_of_nodes()

get_nr_nodes_in_upstream_subgraph(query_node)

Get the number of nodes in the upstream subgraph of the query node.

Parameters:

Name Type Description Default
query_node str

a string representing the name of the node from which the upstream subgraph is constructed

required

Returns:

Type Description

the number of nodes in the upstream subgraph of the query node

Source code in binn/importance_network.py
def get_nr_nodes_in_upstream_subgraph(self, query_node: str):
    """
    Get the number of nodes in the upstream subgraph of the query node.

    Args:
        query_node: a string representing the name of the node from which the upstream subgraph is constructed

    Returns:
        the number of nodes in the upstream subgraph of the query node
    """

    subgraph = self.get_upstream_subgraph(query_node, depth_limit=None)
    return subgraph.number_of_nodes()

get_upstream_subgraph(query_node, depth_limit=None)

Get a subgraph that contains all nodes upstream of the query node up to the given depth limit (if provided).

Parameters:

Name Type Description Default
query_node str

a string representing the name of the node from which the upstream subgraph is constructed

required
depth_limit

an integer representing the maximum depth to which the subgraph is constructed (optional)

None

Returns:

Name Type Description
subgraph

a directed graph (DiGraph) object containing all nodes upstream of the query node, up to the given depth limit

Source code in binn/importance_network.py
def get_upstream_subgraph(self, query_node: str, depth_limit=None):
    """
    Get a subgraph that contains all nodes upstream of the query node up to the given depth limit (if provided).

    Args:
        query_node: a string representing the name of the node from which the upstream subgraph is constructed
        depth_limit: an integer representing the maximum depth to which the subgraph is constructed (optional)

    Returns:
        subgraph: a directed graph (DiGraph) object containing all nodes upstream of the query node, up to the given depth limit
    """
    subgraph = self.get_downstream_subgraph(self.importance_graph_reverse, query_node, depth_limit=depth_limit)
    return subgraph

plot_complete_sankey(multiclass=False, show_top_n=10, node_cmap='Reds', edge_cmap='Reds', savename='sankey.png', width=1900, scale=2, height=800)

Plot a complete Sankey diagram for the importance network.

Parameters:

Name Type Description Default
multiclass bool

If True, plot multiclass Sankey diagram. Defaults to False.

False
show_top_n int

Show only the top N nodes in the Sankey diagram. Defaults to 10.

10
node_cmap str

The color map for the nodes. Defaults to "Reds".

'Reds'
edge_cmap str or list

The color map for the edges. Defaults to "Reds".

'Reds'
savename str

The filename to save the plot. Defaults to "sankey.png".

'sankey.png'

Returns:

Type Description

plotly.graph_objs._figure.Figure: The plotly Figure object representing the Sankey diagram.

Source code in binn/importance_network.py
def plot_complete_sankey(
    self,
    multiclass: bool = False,
    show_top_n: int = 10,
    node_cmap: str = "Reds",
    edge_cmap: Union[str, list] = "Reds",
    savename="sankey.png",
    width: int = 1900,
    scale: float = 2,
    height: int = 800,
):
    """
    Plot a complete Sankey diagram for the importance network.

    Parameters:
        multiclass : bool, optional
            If True, plot multiclass Sankey diagram. Defaults to False.
        show_top_n : int, optional
            Show only the top N nodes in the Sankey diagram. Defaults to 10.
        node_cmap : str, optional
            The color map for the nodes. Defaults to "Reds".
        edge_cmap : str or list, optional
            The color map for the edges. Defaults to "Reds".
        savename : str, optional
            The filename to save the plot. Defaults to "sankey.png".

    Returns:
        plotly.graph_objs._figure.Figure: The plotly Figure object representing the Sankey diagram.
    """

    fig = complete_sankey(
        self.importance_df,
        multiclass=multiclass,
        val_col=self.val_col,
        show_top_n=show_top_n,
        edge_cmap=edge_cmap,
        node_cmap=node_cmap,
    )

    fig.write_image(f"{savename}", width=width, scale=scale, height=height)
    return fig

plot_subgraph_sankey(query_node, upstream=False, savename='sankey.png', val_col='value', cmap='coolwarm', width=1200, scale=2.5, height=500)

Generate a Sankey diagram using the provided query node.

Parameters:

Name Type Description Default
query_node str

The node to use as the starting point for the Sankey diagram.

required
upstream bool

If True, the Sankey diagram will show the upstream flow of the query_node. If False (default), the Sankey diagram will show the downstream flow of the query_node.

False
savename str

The file name to save the Sankey diagram as. Defaults to "sankey.png".

'sankey.png'
val_col str

The column in the DataFrame that represents the value flow between nodes. Defaults to "value".

'value'
cmap str

The name of the color map to use for the Sankey diagram. Defaults to "coolwarm".

'coolwarm'

Returns:

Type Description

plotly.graph_objs._figure.Figure: The plotly Figure object representing the Sankey diagram.
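
Before plotting, the method keeps only the dataframe rows that touch the chosen subgraph: rows whose `target` is in the subgraph for an upstream view, and rows whose `source` is in it for a downstream view. The sketch below reproduces that filter on invented rows using plain dictionaries; `rows_for_subgraph` is a hypothetical helper, not a library function.

```python
def rows_for_subgraph(rows, nodes_in_subgraph, upstream):
    """Keep rows whose target (upstream) or source (downstream) lies in the subgraph."""
    column = "target" if upstream else "source"
    return [row for row in rows if row[column] in nodes_in_subgraph]


rows = [
    {"source": 1, "target": 3, "value": 0.8},
    {"source": 2, "target": 3, "value": 0.2},
    {"source": 3, "target": 0, "value": 1.0},
]

# Downstream of node 1: the subgraph holds nodes 1, 3 and the root 0.
downstream_rows = rows_for_subgraph(rows, {1, 3, 0}, upstream=False)
# Upstream of node 3: the subgraph holds nodes 3, 1 and 2.
upstream_rows = rows_for_subgraph(rows, {3, 1, 2}, upstream=True)

print(len(downstream_rows), len(upstream_rows))
```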

Source code in binn/importance_network.py
def plot_subgraph_sankey(
    self,
    query_node: str,
    upstream: bool = False,
    savename: str = "sankey.png",
    val_col: str = "value",
    cmap: str = "coolwarm",
    width: int = 1200,
    scale: float = 2.5,
    height: int = 500,
):
    """
    Generate a Sankey diagram using the provided query node.

    Args:
        query_node (str): The node to use as the starting point for the Sankey diagram.
        upstream (bool, optional): If True, the Sankey diagram will show the upstream flow of the
            query_node. If False (default), the Sankey diagram will show the downstream flow of the
            query_node.
        savename (str, optional): The file name to save the Sankey diagram as. Defaults to "sankey.png".
        val_col (str, optional): The column in the DataFrame that represents the value flow between
            nodes. Defaults to "value".
        cmap (str, optional): The name of the color map to use for the Sankey diagram. Defaults
            to "coolwarm".

    Returns:
        plotly.graph_objs._figure.Figure: The plotly Figure object representing the Sankey diagram.

    """
    if upstream:
        final_node_id = self.get_node(query_node)
        subgraph = self.get_upstream_subgraph(final_node_id, depth_limit=None)
        source_or_target = "target"
    else:
        if query_node == "root":
            raise ValueError("You cannot look downstream from root")
        final_node_id = self.get_node("root")
        query_node_id = self.get_node(query_node)
        subgraph = self.get_downstream_subgraph(self.importance_graph, query_node_id, depth_limit=None)
        source_or_target = "source"

    nodes_in_subgraph = [n for n in subgraph.nodes]

    df = self.importance_df[
        self.importance_df[source_or_target].isin(nodes_in_subgraph)
    ].copy()

    if df.empty:
        raise ValueError("There are no nodes in the specified subgraph")

    fig = subgraph_sankey(
        df, final_node=final_node_id, val_col=val_col, cmap_name=cmap
    )

    fig.write_image(f"{savename}", width=width, scale=scale, height=height)
    return fig