BINNTrainer
BINNLogger
A minimal logger for BINN.
Source code in binn/model/trainer.py
log(phase, metrics)
Log metrics for a specific phase (train/val).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
phase | str | Phase name: 'train' or 'val'. | required |
metrics | dict | Dictionary mapping metric names to their values. | required |
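To make the log(phase, metrics) contract concrete, here is a minimal sketch of a logger that records one row per call. The class name MinimalLogger, the flat list-of-dicts storage, and the phase check are assumptions for illustration, not BINNLogger's actual internals.

```python
from typing import Any, Dict, List


class MinimalLogger:
    """Hypothetical stand-in for BINNLogger: keeps one row per log() call."""

    VALID_PHASES = ("train", "val")

    def __init__(self) -> None:
        # Each entry looks like {"phase": ..., <metric name>: <value>, ...}.
        self.logs: List[Dict[str, Any]] = []

    def log(self, phase: str, metrics: Dict[str, Any]) -> None:
        # Assumption: reject unknown phases instead of silently storing them.
        if phase not in self.VALID_PHASES:
            raise ValueError(f"phase must be one of {self.VALID_PHASES}, got {phase!r}")
        # Build a new (shallow) dict so later changes to the caller's
        # `metrics` dict do not alter the stored row.
        self.logs.append({"phase": phase, **metrics})


logger = MinimalLogger()
logger.log("train", {"epoch": 1, "loss": 0.69, "accuracy": 0.52})
logger.log("val", {"epoch": 1, "loss": 0.71, "accuracy": 0.50})
```

Storing the phase as an ordinary column keeps train and validation rows in one table, which maps directly onto a single CSV file.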
save_logs()
Save logs to disk as a CSV file.
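A sketch of what saving logs as CSV can look like with Python's standard csv module, assuming the logs are held as a list of dicts (one per log() call). The helper save_logs_csv and the file name are hypothetical; the real save_logs() may use a different library, file name, or column layout.

```python
import csv
import os
import tempfile
from typing import Any, Dict, List


def save_logs_csv(logs: List[Dict[str, Any]], path: str) -> None:
    """Hypothetical helper: write a list of metric dicts to `path` as CSV."""
    # Collect column names in first-seen order so rows with different
    # metric sets still share one header.
    fieldnames: List[str] = []
    for row in logs:
        for key in row:
            if key not in fieldnames:
                fieldnames.append(key)
    # newline="" is what the csv docs recommend so the writer controls line endings.
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        # Keys missing from a row are written as empty strings (default restval="").
        writer.writerows(logs)


logs = [
    {"phase": "train", "epoch": 1, "loss": 0.69, "accuracy": 0.52},
    {"phase": "val", "epoch": 1, "loss": 0.71, "accuracy": 0.50},
]
out_path = os.path.join(tempfile.mkdtemp(), "training_logs.csv")
save_logs_csv(logs, out_path)
```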
BINNTrainer
Handles training BINN models using a raw PyTorch training loop.
__init__(binn_model, save_dir='')
Parameters:
Name | Type | Description | Default |
---|---|---|---|
binn_model | | The BINN model instance to train. | required |
save_dir | str | Directory in which to save logs and/or checkpoints. | '' |
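The empty-string default for save_dir needs some care when building output paths. The hypothetical helper below (not part of BINN) shows one way to handle it, assuming files are placed with os.path.join: an empty save_dir yields a path relative to the current working directory, and directory creation is skipped because os.makedirs("") raises FileNotFoundError. The file names used are illustrative only.

```python
import os
import tempfile


def prepare_output_path(save_dir: str, filename: str) -> str:
    """Hypothetical helper: resolve where a log or checkpoint file would go."""
    if save_dir:
        # Create the directory tree on demand; exist_ok avoids errors on reruns.
        os.makedirs(save_dir, exist_ok=True)
    # os.path.join("", name) returns name unchanged, i.e. a path relative to
    # the current working directory. os.makedirs("") would raise
    # FileNotFoundError, hence the guard above.
    return os.path.join(save_dir, filename)


print(prepare_output_path("", "training_logs.csv"))  # training_logs.csv

run_dir = os.path.join(tempfile.mkdtemp(), "run1")
checkpoint_file = prepare_output_path(run_dir, "checkpoint.pt")
```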
evaluate(dataloader)
Evaluate the BINN model on a dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataloader | DataLoader | DataLoader for the evaluation dataset. | required |
Returns:
Type | Description |
---|---|
dict | Dictionary with keys 'loss' and 'accuracy', computed on the evaluation set. |
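The metrics returned by evaluate can be illustrated without a framework. The sketch assumes a multi-class classifier scored with cross-entropy and averages over samples rather than over batches; BINNTrainer's actual loss function and averaging scheme are not documented here, and evaluate_sketch, predict, and the (inputs, labels) batch format are hypothetical stand-ins for the BINN model and a PyTorch DataLoader.

```python
import math
from typing import Callable, Iterable, List, Sequence, Tuple

Batch = Tuple[Sequence, Sequence[int]]  # (inputs, integer class labels)


def evaluate_sketch(
    predict: Callable[[Sequence], List[List[float]]],
    dataloader: Iterable[Batch],
) -> dict:
    """Sample-weighted mean cross-entropy and accuracy over all batches."""
    total_loss, total_correct, total_n = 0.0, 0, 0
    for inputs, labels in dataloader:
        logits = predict(inputs)  # one list of class scores per sample
        for row, label in zip(logits, labels):
            # Numerically stable log-softmax via the log-sum-exp trick.
            m = max(row)
            log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
            total_loss += log_sum_exp - row[label]  # -log p(label)
            # argmax; ties resolve to the lowest class index.
            pred = max(range(len(row)), key=row.__getitem__)
            total_correct += int(pred == label)
            total_n += 1
    if total_n == 0:
        raise ValueError("dataloader yielded no samples")
    return {"loss": total_loss / total_n, "accuracy": total_correct / total_n}


# Inputs are already logits here, so `predict` is the identity function.
batches = [
    ([[10.0, 0.0]], [0]),                                 # 1 sample, correct
    ([[0.0, 0.0], [0.0, 5.0], [5.0, 0.0]], [1, 0, 1]),    # 3 samples, all wrong
]
metrics = evaluate_sketch(lambda x: x, batches)
```

With these batches, sample weighting gives an accuracy of 0.25 (1 of 4 samples correct), whereas averaging the two per-batch accuracies would give 0.5, so the choice matters whenever batch sizes differ.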
fit(dataloaders, num_epochs=30, learning_rate=0.0001, checkpoint_path=None)
Train the BINN model using a standard PyTorch training loop.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataloaders | dict | Dictionary with a "train" key holding the DataLoader for training data and an optional "val" key holding the DataLoader for validation data. | required |
num_epochs | int | Number of training epochs. | 30 |
learning_rate | float | Learning rate for the optimizer. | 0.0001 |
checkpoint_path | str | Optional path at which to save model checkpoints. | None |
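The sketch below mirrors the control flow documented for fit (epochs over the "train" loader, an optional "val" pass, an optional checkpoint write) without PyTorch: a one-feature logistic-regression model trained with plain mini-batch SGD stands in for the BINN model and its optimizer. Everything here (fit_sketch, _run_epoch, the JSON checkpoint, lists of batches in place of DataLoaders) is hypothetical; the real method uses PyTorch DataLoaders and whatever loss and optimizer BINNTrainer configures. The default learning rate matches the documented 0.0001, but the usage passes 0.5 because this toy model needs larger steps.

```python
import json
import math
from typing import Dict, List, Optional, Sequence, Tuple

Batch = Sequence[Tuple[float, int]]  # list of (feature, binary label) pairs


def _sigmoid(z: float) -> float:
    # Branch on sign so math.exp never receives a large positive argument.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def _run_epoch(params: Dict[str, float], batches: List[Batch],
               learning_rate: Optional[float] = None) -> Dict[str, float]:
    """One pass over `batches`; updates params in place only if learning_rate is given."""
    total_loss, correct, n = 0.0, 0, 0
    for batch in batches:
        grad_w = grad_b = 0.0
        for x, y in batch:
            p = _sigmoid(params["w"] * x + params["b"])
            p_c = min(max(p, 1e-12), 1 - 1e-12)  # clamp before taking logs
            total_loss += -(y * math.log(p_c) + (1 - y) * math.log(1 - p_c))
            correct += int((p >= 0.5) == (y == 1))
            n += 1
            grad_w += (p - y) * x  # gradient of binary cross-entropy w.r.t. w
            grad_b += p - y
        if learning_rate is not None:  # training: one SGD step per batch
            params["w"] -= learning_rate * grad_w / len(batch)
            params["b"] -= learning_rate * grad_b / len(batch)
    # Training metrics are accumulated while the parameters change (running averages).
    return {"loss": total_loss / n, "accuracy": correct / n}


def fit_sketch(dataloaders: Dict[str, List[Batch]], num_epochs: int = 30,
               learning_rate: float = 0.0001, checkpoint_path: Optional[str] = None):
    params = {"w": 0.0, "b": 0.0}
    history = []
    for epoch in range(1, num_epochs + 1):
        train_metrics = _run_epoch(params, dataloaders["train"], learning_rate)
        history.append({"phase": "train", "epoch": epoch, **train_metrics})
        if "val" in dataloaders:  # validation is optional
            val_metrics = _run_epoch(params, dataloaders["val"])
            history.append({"phase": "val", "epoch": epoch, **val_metrics})
        if checkpoint_path is not None:  # overwrite with the latest parameters
            with open(checkpoint_path, "w") as f:
                json.dump({"epoch": epoch, **params}, f)
    return params, history


xs = [i / 10 for i in range(-20, 21) if i != 0]  # -2.0 .. 2.0, zero excluded
train = [[(x, int(x > 0)) for x in xs[k:k + 8]] for k in range(0, len(xs), 8)]
val = [[(x, int(x > 0)) for x in (i / 10 + 0.05 for i in range(-20, 20))]]
params, history = fit_sketch({"train": train, "val": val}, num_epochs=30, learning_rate=0.5)
```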