BINNTrainer

BINNLogger

A minimal logger for BINN.

Source code in `binn/model/trainer.py` (lines 129–156)
class BINNLogger:
    """
    A minimal logger for BINN.

    Collects one list of metric dictionaries per phase in memory and writes
    each non-empty phase to ``<save_dir>/<phase>_logs.csv`` on request.
    """

    def __init__(self, save_dir):
        """
        Args:
            save_dir (str | os.PathLike): Directory the CSV files are written
                to. An empty string means the current working directory.
        """
        self.save_dir = save_dir
        self.logs = {"train": [], "val": []}

    def log(self, phase, metrics):
        """
        Log metrics for a specific phase (train/val).

        Phases other than 'train' and 'val' (e.g. 'test') are created on
        first use instead of raising ``KeyError``.

        Args:
            phase (str): Phase name ('train' or 'val').
            metrics (dict): Dictionary containing metric names and values.
        """
        # Store a shallow copy so a caller that reuses/mutates its dict
        # cannot retroactively change rows that were already logged.
        self.logs.setdefault(phase, []).append(dict(metrics))

    def save_logs(self):
        """
        Save logs to disk as CSV files, one ``<phase>_logs.csv`` per phase
        that has at least one logged entry.

        The target directory is created if it does not exist yet.
        """
        import os

        import pandas as pd

        non_empty = {phase: rows for phase, rows in self.logs.items() if rows}
        if not non_empty:
            return
        if self.save_dir:
            os.makedirs(self.save_dir, exist_ok=True)
        for phase, rows in non_empty.items():
            # os.path.join (not f"{save_dir}/...") so that save_dir="" maps to
            # the working directory rather than "/<phase>_logs.csv" at the root.
            path = os.path.join(self.save_dir, f"{phase}_logs.csv")
            pd.DataFrame(rows).to_csv(path, index=False)

log(phase, metrics)

Log metrics for a specific phase (train/val).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `phase` | `str` | Phase name (`'train'` or `'val'`). | *required* |
| `metrics` | `dict` | Dictionary containing metric names and values. | *required* |
Source code in `binn/model/trainer.py` (lines 138–146)
def log(self, phase, metrics):
    """
    Log metrics for a specific phase (train/val).

    Phases other than 'train' and 'val' (e.g. 'test') are created on first
    use instead of raising ``KeyError``.

    Args:
        phase (str): Phase name ('train' or 'val').
        metrics (dict): Dictionary containing metric names and values.
    """
    # Shallow copy: later mutation of the caller's dict must not alter
    # rows that have already been logged.
    self.logs.setdefault(phase, []).append(dict(metrics))

save_logs()

Save logs to disk as CSV files: one `<phase>_logs.csv` for each phase that has logged entries.

Source code in `binn/model/trainer.py` (lines 148–156)
def save_logs(self):
    """
    Save logs to disk as CSV files, one ``<phase>_logs.csv`` per phase that
    has at least one logged entry.

    Files are written to ``self.save_dir``. An empty ``save_dir`` means the
    current working directory, and a missing directory is created.
    """
    import os

    import pandas as pd

    non_empty = {phase: rows for phase, rows in self.logs.items() if rows}
    if not non_empty:
        return
    if self.save_dir:
        os.makedirs(self.save_dir, exist_ok=True)
    for phase, rows in non_empty.items():
        # os.path.join (not f"{save_dir}/...") so that save_dir="" does not
        # resolve to "/<phase>_logs.csv" at the filesystem root.
        path = os.path.join(self.save_dir, f"{phase}_logs.csv")
        pd.DataFrame(rows).to_csv(path, index=False)

BINNTrainer

Handles training BINN models using a raw PyTorch training loop.

Source code in `binn/model/trainer.py` (lines 5–125)
class BINNTrainer:
    """
    Handles training BINN models using a raw PyTorch training loop.

    Loss and accuracy are sample-weighted means over a whole dataloader, so a
    short final batch is not over-weighted. A dataloader that yields no
    samples produces ``nan`` metrics instead of raising.

    Metrics computed during ``fit`` are also recorded in ``self.logger``
    under the "train" and "val" phases. Call ``self.logger.save_logs()`` to
    write them to disk.
    """

    def __init__(self, binn_model, save_dir: str = ""):
        """
        Args:
            binn_model: The BINN model instance to train. It is expected to
                expose a ``device`` attribute; every batch is moved there.
            save_dir (str): Directory to save logs and/or checkpoints.
        """
        self.save_dir = save_dir
        self.network = binn_model
        self.logger = BINNLogger(save_dir=save_dir)

    def _to_device(self, inputs, targets):
        """Move one ``(inputs, targets)`` batch onto the model's device."""
        device = self.network.device
        return inputs.to(device), targets.to(device)

    @staticmethod
    def _batch_totals(outputs, targets, loss):
        """Return (loss summed over the batch, #correct predictions, batch size)."""
        batch_size = targets.size(0)
        # `loss` is the batch *mean* from F.cross_entropy; rescale it to a sum
        # so that dividing by the total sample count weights every example equally.
        loss_sum = loss.item() * batch_size
        correct = (torch.argmax(outputs, dim=1) == targets).sum().item()
        return loss_sum, correct, batch_size

    @staticmethod
    def _mean_or_nan(total, count):
        """Divide ``total`` by ``count``; NaN when no samples were seen."""
        return total / count if count else float("nan")

    def fit(
        self,
        dataloaders: dict,
        num_epochs: int = 30,
        learning_rate: float = 1e-4,
        checkpoint_path: str = None,
    ):
        """
        Train the BINN model using a standard PyTorch training loop.

        Args:
            dataloaders (dict): Dictionary containing:
                - "train": DataLoader for training data.
                - "val" (optional): DataLoader for validation data.
            num_epochs (int): Number of training epochs.
            learning_rate (float): Learning rate for the optimizer.
            checkpoint_path (str): Path prefix for per-epoch checkpoints
                (optional). The file name is ``{checkpoint_path}_epoch{N}.pt``.
        """
        optimizer = torch.optim.Adam(self.network.parameters(), lr=learning_rate)

        for epoch in range(num_epochs):
            self.network.train()
            total_loss, total_correct, num_samples = 0.0, 0, 0

            for inputs, targets in dataloaders["train"]:
                inputs, targets = self._to_device(inputs, targets)

                optimizer.zero_grad()
                outputs = self.network(inputs)
                loss = F.cross_entropy(outputs, targets)
                loss.backward()
                optimizer.step()

                loss_sum, correct, batch_size = self._batch_totals(outputs, targets, loss)
                total_loss += loss_sum
                total_correct += correct
                num_samples += batch_size

            train_metrics = {
                "loss": self._mean_or_nan(total_loss, num_samples),
                "accuracy": self._mean_or_nan(total_correct, num_samples),
            }
            print(
                f"[Epoch {epoch+1}/{num_epochs}] "
                f"Train Loss: {train_metrics['loss']:.4f}, "
                f"Train Accuracy: {train_metrics['accuracy']:.4f}"
            )
            self.logger.log("train", {"epoch": epoch + 1, **train_metrics})

            if "val" in dataloaders:
                # Same computation as a standalone evaluation; reuse it.
                val_metrics = self.evaluate(dataloaders["val"])
                print(
                    f"[Epoch {epoch+1}/{num_epochs}] "
                    f"Val Loss: {val_metrics['loss']:.4f}, "
                    f"Val Accuracy: {val_metrics['accuracy']:.4f}"
                )
                self.logger.log("val", {"epoch": epoch + 1, **val_metrics})

            if checkpoint_path:
                torch.save(self.network.state_dict(), f"{checkpoint_path}_epoch{epoch+1}.pt")

    def evaluate(self, dataloader):
        """
        Evaluate the BINN model on a dataset.

        Args:
            dataloader (DataLoader): DataLoader for the evaluation dataset.
                Any iterable of ``(inputs, targets)`` tensor pairs works; its
                length is not required.

        Returns:
            dict: A dictionary with 'loss' and 'accuracy' on the evaluation
            set, as sample-weighted means (``nan`` if no samples were seen).
        """
        self.network.eval()
        total_loss, total_correct, num_samples = 0.0, 0, 0

        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs, targets = self._to_device(inputs, targets)

                outputs = self.network(inputs)
                loss = F.cross_entropy(outputs, targets)

                loss_sum, correct, batch_size = self._batch_totals(outputs, targets, loss)
                total_loss += loss_sum
                total_correct += correct
                num_samples += batch_size

        return {
            "loss": self._mean_or_nan(total_loss, num_samples),
            "accuracy": self._mean_or_nan(total_correct, num_samples),
        }

    def update_model(self, new_binn_model):
        """
        Replace the model being trained and start a fresh logger.

        Args:
            new_binn_model: The BINN model instance to train from now on.
        """
        # fit/evaluate read self.network. The previous implementation only set
        # self.binn_model, so the swap silently had no effect.
        self.network = new_binn_model
        self.binn_model = new_binn_model  # kept for backward compatibility
        self.logger = BINNLogger(save_dir=self.save_dir)

__init__(binn_model, save_dir='')

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `binn_model` | | The BINN model instance to train. | *required* |
| `save_dir` | `str` | Directory to save logs and/or checkpoints. | `''` |
Source code in `binn/model/trainer.py` (lines 10–18)
def __init__(self, binn_model, save_dir: str = ""):
    """
    Wire a trainer around a BINN model.

    Args:
        binn_model: The BINN model instance to train; kept as ``self.network``.
        save_dir (str): Directory to save logs and/or checkpoints; it is also
            handed to the ``BINNLogger`` created here.
    """
    self.save_dir, self.network = save_dir, binn_model
    self.logger = BINNLogger(save_dir=self.save_dir)

evaluate(dataloader)

Evaluate the BINN model on a dataset.

Parameters:

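| Name | Type | Description | Default |
|------|------|-------------|---------|
| `dataloader` | `DataLoader` | DataLoader for the evaluation dataset. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `dict` | A dictionary with `'loss'` and `'accuracy'` on the evaluation set. |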

Source code in `binn/model/trainer.py` (lines 94–121)
def evaluate(self, dataloader):
    """
    Evaluate the BINN model on a dataset.

    Loss and accuracy are sample-weighted means over every example the
    dataloader yields, so a short final batch is not over-weighted.

    Args:
        dataloader (DataLoader): DataLoader for the evaluation dataset. Any
            iterable of ``(inputs, targets)`` tensor pairs works; its length
            is not required.

    Returns:
        dict: A dictionary with 'loss' and 'accuracy' on the evaluation set.
        Both are ``nan`` if the dataloader yields no samples.
    """
    self.network.eval()
    total_loss, total_correct, num_samples = 0.0, 0, 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(self.network.device)
            targets = targets.to(self.network.device)

            outputs = self.network(inputs)
            batch_size = targets.size(0)
            # F.cross_entropy returns the batch mean; scale back to a sum so
            # batches of different sizes are weighted by their sample count.
            total_loss += F.cross_entropy(outputs, targets).item() * batch_size
            total_correct += (torch.argmax(outputs, dim=1) == targets).sum().item()
            num_samples += batch_size

    if num_samples == 0:
        # Previously a ZeroDivisionError; report "no data" explicitly instead.
        return {"loss": float("nan"), "accuracy": float("nan")}
    return {"loss": total_loss / num_samples, "accuracy": total_correct / num_samples}

fit(dataloaders, num_epochs=30, learning_rate=0.0001, checkpoint_path=None)

Train the BINN model using a standard PyTorch training loop.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `dataloaders` | `dict` | Dictionary containing `"train"` (DataLoader for training data) and, optionally, `"val"` (DataLoader for validation data). | *required* |
| `num_epochs` | `int` | Number of training epochs. | `30` |
| `learning_rate` | `float` | Learning rate for the optimizer. | `0.0001` |
| `checkpoint_path` | `str` | Path to save model checkpoints (optional). | `None` |
Source code in `binn/model/trainer.py` (lines 21–92)
def fit(
    self,
    dataloaders: dict,
    num_epochs: int = 30,
    learning_rate: float = 1e-4,
    checkpoint_path: str = None,
):
    """
    Train the BINN model using a standard PyTorch training loop.

    Per-epoch loss and accuracy are sample-weighted means, so a short final
    batch counts proportionally. They are printed and recorded through
    ``self.logger.log`` under the "train" and "val" phases. A loader that
    yields no samples reports ``nan`` instead of raising.

    Args:
        dataloaders (dict): Dictionary containing:
            - "train": DataLoader for training data.
            - "val" (optional): DataLoader for validation data.
        num_epochs (int): Number of training epochs.
        learning_rate (float): Learning rate for the optimizer.
        checkpoint_path (str): Path prefix for per-epoch checkpoints
            (optional). The file name is ``{checkpoint_path}_epoch{N}.pt``.
    """

    def _mean_or_nan(total, count):
        # NaN rather than ZeroDivisionError when a loader is empty.
        return total / count if count else float("nan")

    optimizer = torch.optim.Adam(self.network.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):

        self.network.train()
        train_loss, train_correct, train_samples = 0.0, 0, 0

        for inputs, targets in dataloaders["train"]:
            inputs = inputs.to(self.network.device)
            targets = targets.to(self.network.device)

            optimizer.zero_grad()
            outputs = self.network(inputs)
            loss = F.cross_entropy(outputs, targets)
            loss.backward()
            optimizer.step()

            batch_size = targets.size(0)
            # cross_entropy returns the batch mean; re-weight by batch size.
            train_loss += loss.item() * batch_size
            train_correct += (torch.argmax(outputs, dim=1) == targets).sum().item()
            train_samples += batch_size

        avg_train_loss = _mean_or_nan(train_loss, train_samples)
        avg_train_accuracy = _mean_or_nan(train_correct, train_samples)

        print(
            f"[Epoch {epoch+1}/{num_epochs}] "
            f"Train Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_accuracy:.4f}"
        )
        self.logger.log(
            "train", {"epoch": epoch + 1, "loss": avg_train_loss, "accuracy": avg_train_accuracy}
        )

        if "val" in dataloaders:
            self.network.eval()
            val_loss, val_correct, val_samples = 0.0, 0, 0
            with torch.no_grad():
                for inputs, targets in dataloaders["val"]:
                    inputs = inputs.to(self.network.device)
                    targets = targets.to(self.network.device)

                    outputs = self.network(inputs)
                    batch_size = targets.size(0)
                    val_loss += F.cross_entropy(outputs, targets).item() * batch_size
                    val_correct += (torch.argmax(outputs, dim=1) == targets).sum().item()
                    val_samples += batch_size

            avg_val_loss = _mean_or_nan(val_loss, val_samples)
            avg_val_accuracy = _mean_or_nan(val_correct, val_samples)

            print(
                f"[Epoch {epoch+1}/{num_epochs}] "
                f"Val Loss: {avg_val_loss:.4f}, Val Accuracy: {avg_val_accuracy:.4f}"
            )
            self.logger.log(
                "val", {"epoch": epoch + 1, "loss": avg_val_loss, "accuracy": avg_val_accuracy}
            )

        if checkpoint_path:
            torch.save(self.network.state_dict(), f"{checkpoint_path}_epoch{epoch+1}.pt")