Scikit-learn API
BINNClassifier
Bases: BaseEstimator, ClassifierMixin
A scikit-learn wrapper for the BINN.
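Inheriting from BaseEstimator is what gives a scikit-learn wrapper its `get_params`/`set_params` methods. The sketch below is a simplified, stdlib-only illustration of how that introspection works, assuming (as scikit-learn requires) that every constructor argument is stored on the instance under the same name. `MiniEstimator` and `ToyClassifier` are hypothetical stand-ins, not part of binn or scikit-learn.

```python
import inspect


class MiniEstimator:
    """Hypothetical, simplified stand-in for sklearn.base.BaseEstimator."""

    @classmethod
    def _param_names(cls):
        # Read parameter names from the subclass's __init__ signature,
        # skipping `self` and any *args / **kwargs catch-alls.
        sig = inspect.signature(cls.__init__)
        return sorted(
            name
            for name, p in sig.parameters.items()
            if name != "self" and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
        )

    def get_params(self):
        # Works only if each constructor argument is stored under the same name.
        return {name: getattr(self, name) for name in self._param_names()}

    def set_params(self, **params):
        valid = set(self._param_names())
        for key, value in params.items():
            if key not in valid:
                raise ValueError(f"Invalid parameter {key!r}")
            setattr(self, key, value)
        return self


class ToyClassifier(MiniEstimator):
    # Hypothetical: a few defaults copied from the Args list of BINNClassifier.
    def __init__(self, learning_rate=1e-4, epochs=100, log_steps=50):
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.log_steps = log_steps


clf = ToyClassifier(epochs=10)
print(clf.get_params())  # {'epochs': 10, 'learning_rate': 0.0001, 'log_steps': 50}
clf.set_params(learning_rate=1e-3)
```

The real scikit-learn implementation additionally supports nested parameters (`step__param`) and is what `sklearn.base.clone` relies on; this sketch omits both.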
Args:
pathways : Network, optional
The network architecture to use for the classifier. If None, a default
architecture will be used. Default is None.
activation : str, optional
The activation function to use for the classifier. Default is 'tanh'.
weight : torch.Tensor, optional
The weight to assign to each class. Default is torch.Tensor([1, 1]).
learning_rate : float, optional
The learning rate for the optimizer. Default is 1e-4.
n_layers : int, optional
The number of layers in the network architecture. Default is 4.
scheduler : str, optional
The scheduler to use for the optimizer. Default is 'plateau'.
optimizer : str, optional
The optimizer to use for training. Default is 'adam'.
n_outputs : int, optional
The number of outputs of the network architecture. Default is 2.
dropout : float, optional
The dropout rate to use for the classifier. Default is 0.
residual : bool, optional
Whether to use residual connections in the network architecture.
Default is False.
threads : int, optional
The number of threads to use for data loading. Default is 1.
epochs : int, optional
The number of epochs to train the classifier for. Default is 100.
logger : Union[SuperLogger, None], optional
The logger to use for logging training information. Default is None.
log_steps : int, optional
The number of steps between each log message during training.
Default is 50.
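The `weight` argument assigns one weight per class, defaulting to equal weights. Assuming it is passed to a class-weighted cross-entropy loss in the same way as the `weight` argument of `torch.nn.CrossEntropyLoss` (the source does not confirm this), the loss with mean reduction is the weighted sum of per-sample negative log-likelihoods divided by the sum of the weights of the target classes. A plain-Python sketch of that formula, with the hypothetical helper `weighted_cross_entropy`:

```python
import math


def weighted_cross_entropy(logits, targets, weight):
    """Class-weighted cross-entropy with mean reduction.

    Mirrors the formula documented for torch.nn.CrossEntropyLoss(weight=...);
    whether BINNClassifier uses exactly this loss is an assumption.
    """
    total, norm = 0.0, 0.0
    for z, y in zip(logits, targets):
        # Numerically stable log-sum-exp for the softmax denominator.
        m = max(z)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in z))
        nll = log_sum_exp - z[y]  # -log softmax(z)[y]
        total += weight[y] * nll
        norm += weight[y]
    return total / norm


logits = [[2.0, 0.5], [0.2, 1.5], [1.0, 1.0]]
targets = [0, 1, 1]
print(weighted_cross_entropy(logits, targets, [1.0, 1.0]))  # equal weights: plain mean
print(weighted_cross_entropy(logits, targets, [1.0, 3.0]))  # class 1 counts three times as much
```

Because the weights appear in both numerator and denominator, only their ratios matter: `[1, 1]` and `[2, 2]` give the same loss.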
Attributes:
Name | Type | Description |
---|---|---|
clf | BINN | The BINN (Biologically Informed Neural Network) instance used for classification. |
threads | int | The number of threads used for data loading. |
epochs | int | The number of epochs to train the classifier for. |
logger | Union[SuperLogger, None] | The logger used for logging training information. |
log_steps | int | The number of steps between log messages during training. |
Source code in binn/sklearn.py
fit(X, y, epochs)
Trains the classifier using the provided input data and target labels.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
X | array-like of shape (n_samples, n_features) | The input data. | required |
y | array-like of shape (n_samples,) | The target labels. | required |
Returns:
Type | Description |
---|---|
None |
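The Parameters table fixes a shape contract: `X` is 2-D with one row per sample, and `y` holds one label per row of `X`. In scikit-learn this check is usually delegated to `sklearn.utils.validation.check_X_y`; whether `fit` does so is not stated. The stdlib sketch below shows the same contract on nested lists; `check_fit_inputs` is a hypothetical helper.

```python
def check_fit_inputs(X, y):
    """Hypothetical helper: validate the shapes documented for fit()."""
    n_samples = len(X)
    if n_samples == 0:
        raise ValueError("X must contain at least one sample")
    n_features = len(X[0])
    # Every row must have the same number of features (a 2-D array).
    if any(len(row) != n_features for row in X):
        raise ValueError("X must be 2-D: every row needs n_features values")
    # One target label per sample.
    if len(y) != n_samples:
        raise ValueError(f"y has {len(y)} labels but X has {n_samples} samples")
    return n_samples, n_features


X = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
y = [0, 1]
print(check_fit_inputs(X, y))  # (2, 3)
```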
predict(X)
Predicts target labels for the provided input data using the trained classifier.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
X | array-like of shape (n_samples, n_features) | The input data. | required |
Returns:
Name | Type | Description |
---|---|---|
y_hat | torch.Tensor of shape (n_samples,) | The predicted target labels. |
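`predict` returns one label per sample, while the network has `n_outputs` output units (2 by default). A common way to bridge the two is an argmax over each sample's output scores; assuming BINNClassifier does the same (the source does not say), the step looks like this in plain Python. `predict_labels` is hypothetical; on a PyTorch tensor the equivalent would be `torch.argmax(scores, dim=1)`.

```python
def predict_labels(scores):
    """Hypothetical: map scores of shape (n_samples, n_outputs) to labels of shape (n_samples,)."""
    # For each sample, pick the index of the highest score (first one wins on ties).
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


scores = [[0.9, 0.1], [0.3, 0.7], [0.6, 0.4]]
print(predict_labels(scores))  # [0, 1, 0]
```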