
Models Module

energy_gnome.models

energy_gnome.models.E3NNRegressor

Bases: BaseRegressor

Source code in energy_gnome/models/e3nn/regressor.py
Python
class E3NNRegressor(BaseRegressor):
    def __init__(
        self,
        model_name: str,
        target_property: str,
        models_dir: Path | str = MODELS_DIR,
        figures_dir: Path | str = FIGURES_DIR,
    ):
        """
        Initialize the E3NNRegressor with directories for storing models and figures.

        This class extends `BaseRegressor` to implement an equivariant neural network (E3NN)
        for regression tasks. It sets up the necessary directory structure and configurations
        for training models.

        Args:
            model_name (str): Name of the model, used to create subdirectories.
            target_property (str): The target property the model is trained to predict.
            models_dir (Path | str, optional): Directory for storing trained model weights.
                                            Defaults to `MODELS_DIR` from config.
            figures_dir (Path | str, optional): Directory for saving figures and visualizations.
                                                Defaults to `FIGURES_DIR` from config.

        Attributes:
            _model_spec (str): Specification string used for model identification.
            l_max (int): Maximum order of the spherical harmonics used in the E3NN model (default is 2).
            r_max (int): Cutoff radius used in the E3NN model (default is 4).
            conv_layers (int): Number of nonlinearities (number of convolutions = layers + 1, default is 2).
        """

        self._model_spec = "model.e3nn_regressor." + target_property
        super().__init__(
            model_name=model_name,
            target_property=target_property,
            models_dir=models_dir,
            figures_dir=figures_dir,
        )
        self.l_max: int = 2
        self.r_max: int = 4
        self.conv_layers: int = 2

    def _find_model_states(self):
        models_states = []
        if any(self.models_dir.iterdir()):
            models_states = [f_ for f_ in self.models_dir.iterdir() if f_.match("*.torch")]
        return models_states

    def load_trained_models(self, state: str = "state_best"):
        """
        Load trained models from the model directory.

        This method searches for trained models by:
        1. Loading model settings from `.yaml` files matching `_model_spec`.
        2. Initializing models based on corresponding `.json` configuration files.
        3. Loading the model weights from `.torch` files.
        4. Storing the loaded models in `self.models`.

        Args:
            state (str, optional): The key used to extract model weights from the saved
                                state dictionary (e.g., `"state_best"`). Defaults to `"state_best"`.

        Returns:
            list[str]: A list of `.torch` model filenames that were found in the directory.
        """
        for yaml_path in self.models_dir.glob(f"*{self._model_spec}.yaml"):
            self.set_model_settings(yaml_file=yaml_path)

        i = 0
        loaded_models = []
        for model_path in self.models_dir.glob("*.torch"):
            if self._model_spec in model_path.name:
                model_setting_path = model_path.with_suffix(".json")
                if model_setting_path.exists():
                    try:
                        logger.info(f"Loading model with setting in {model_setting_path}")
                        logger.info(f"And weights in {model_path}")
                        model_setting = load_json(model_setting_path)

                        model = PeriodicNetwork(**model_setting)
                        model.load_state_dict(
                            torch.load(model_path, map_location=self.device, weights_only=True)[
                                state
                            ]
                        )
                        model.pool = True
                        self.models[f"model_{i}"] = model
                        loaded_models.append(model_path.name)
                        i += 1
                    except Exception as e:
                        logger.error(f"Error loading model {model_path.name}: {e}")
                else:
                    logger.warning(f"Missing JSON settings for model weights in {model_path.name}")

        return loaded_models

    def _load_model_setting(self, yaml_path):
        """
        Load model settings from a YAML file and assign corresponding attributes.

        This method loads settings from the specified YAML file and sets model attributes
        based on the values in `DEFAULT_E3NN_SETTINGS`. If an attribute is missing, a KeyError
        will be raised.

        Args:
            yaml_path (Path): Path to the YAML file containing the model settings.
        """
        try:
            settings = load_yaml(yaml_path)
            for att, _ in DEFAULT_E3NN_SETTINGS.items():
                setattr(self, att, settings[att])

            # Ensure that the device is set correctly
            if "cuda" in self.device and not torch.cuda.is_available():
                logger.warning(f"Models trained on {self.device} but only the CPU is available.")
                self.device = "cpu"
        except Exception as e:
            logger.error(f"Error loading model settings from {yaml_path}: {e}")
            raise

    def _save_model_settings(self):
        """
        Save the current model settings to a YAML file.

        This method saves the current values of the model's attributes (based on `DEFAULT_E3NN_SETTINGS`)
        to a YAML file in the models directory.

        The file is named based on the model specification (`self._model_spec`) with a `.yaml` extension.

        Raises:
            IOError: If the saving process fails.
        """
        try:
            settings_path = self.models_dir / (self._model_spec + ".yaml")
            settings = {att: getattr(self, att) for att, _ in DEFAULT_E3NN_SETTINGS.items()}

            logger.info(f"Saving model settings to {settings_path}")
            save_yaml(settings, settings_path)
        except Exception as e:
            logger.error(f"Error saving model settings to {settings_path}: {e}")
            raise

    def set_model_settings(self, yaml_file: Path | str | None = None, **kargs):
        """
        Set model settings either from a YAML file or provided keyword arguments.

        This method allows setting model settings from multiple sources:
        1. If a `yaml_file` is provided, it loads the settings from that file.
        2. If additional settings are provided as keyword arguments (`kargs`), they overwrite
        the default or loaded settings.

        Args:
            yaml_file (Path, str, optional): Path to the YAML file containing the model settings.
            kargs (dict, optional): Dictionary of model settings to override the default ones.

        """
        # Accessing model settings (YAML FILE)
        if yaml_file:
            self._load_model_setting(yaml_file)

        # Accessing model settings (DEFAULT or provided in kargs)
        for att, defvalue in DEFAULT_E3NN_SETTINGS.items():
            if att in kargs:
                # If a setting is provided via kargs, use it
                setattr(self, att, kargs[att])
            else:
                try:
                    # Check whether the attribute already exists
                    att_exist = getattr(self, att)
                    # Treat NaN as "not set" (NaN != NaN evaluates to False)
                    att_exist = att_exist == att_exist
                except AttributeError:
                    # The attribute does not exist yet
                    att_exist = False

                if not att_exist:
                    # Attribute missing or NaN: fall back to the default value
                    setattr(self, att, defvalue)
                    logger.warning(f"Using default value {defvalue} for {att} setting")

        # If yaml_file was not provided or is in a different directory, save settings
        if yaml_file is None or os.path.dirname(str(yaml_file)) != str(self.models_dir):
            self._save_model_settings()

    def set_training_settings(self, n_epochs: int):
        """
        Set the number of epochs for training.

        This method sets the number of epochs for the model's training process.
        It is assumed that the training process will be carried out for the specified
        number of epochs.

        Args:
            n_epochs (int): The number of epochs for training.
                            It should be a positive integer.
        """
        self.n_epochs = n_epochs

    def set_optimizer_settings(self, lr: float, wd: float):
        """
        Set the optimizer settings, including learning rate and weight decay.

        This method sets the learning rate and weight decay for the optimizer, which
        will be used in the training process.

        Args:
            lr (float): The learning rate for the optimizer. It should be a positive float.
            wd (float): The weight decay (regularization) parameter for the optimizer.
                        It should be a non-negative float.
        """
        self.learning_rate = lr
        self.weight_decay = wd

    def featurize_db(self, dataset: pd.DataFrame) -> pd.DataFrame:
        """
        Featurize the given dataset by processing the CIF file paths and extracting
        structural and chemical information.

        This method reads the CIF files specified in the input dataset, extracts chemical
        information (such as formulae and species), and then generates featurized data suitable
        for model training. It also preserves the target property if it exists in the dataset.

        Args:
            dataset (pd.DataFrame): Input dataset containing the CIF file paths and optionally
                                    the target property.

        Returns:
            (pd.DataFrame): The featurized dataset, including chemical information and features
                        for model training. The dataset includes columns for 'structure', 'species',
                        'formula', and 'data' (featurized data).
        """
        if dataset.empty:
            return pd.DataFrame()

        # Initialize tqdm for apply functions
        tqdm.pandas()

        # Build a fresh DataFrame so the original dataset is left untouched
        database_nn = pd.DataFrame(
            {
                "structure": pd.Series(dtype="object"),
                "species": pd.Series(dtype="object"),
            }
        )

        # Read all structures in parallel with progress tracking
        database_nn["structure"] = dataset["cif_path"]
        database_nn["structure"] = database_nn["structure"].progress_apply(
            lambda path: read(Path(to_unix(path)))
        )

        # Extract formula and species info
        database_nn["formula"] = database_nn["structure"].progress_apply(
            lambda s: s.get_chemical_formula()
        )
        database_nn["species"] = database_nn["structure"].progress_apply(
            lambda s: list(set(s.get_chemical_symbols()))
        )

        # Preserve target property
        if self.target_property in dataset.columns:
            database_nn[self.target_property] = dataset[self.target_property]

        # Get encoding for featurization
        type_encoding, type_onehot, atomicmass_onehot = get_encoding()
        r_max = self.r_max  # cutoff radius

        # Apply featurization with progress bar
        database_nn["data"] = database_nn.progress_apply(
            lambda x: build_data(
                x,
                type_encoding,
                type_onehot,
                atomicmass_onehot,
                self.target_property if self.target_property in database_nn.columns else None,
                r_max,
                dtype=DEFAULT_DTYPE,
            ),
            axis=1,
        )

        return database_nn

    def create_dataloader(
        self,
        databases: BaseDatabase | list[BaseDatabase],
        subset: str | None = None,
        shuffle: bool = False,
        confidence_threshold: float = 0.5,
    ):
        """
        Format and return a PyTorch DataLoader for training, validation, or testing.

        This method prepares the dataloaders for PyTorch training by featurizing the input datasets
        and handling multiple types of databases. It ensures that shuffling is only applied to the
        training subset, and it filters out low-confidence samples from GNoMEDatabase.

        Args:
            databases (BaseDatabase | list[BaseDatabase]): A single instance or a list of `BaseDatabase` objects
                                                        containing processed data. If multiple databases are provided,
                                                        they will be concatenated.
            subset (str, optional): The specific data subset to use (`training`, `validation`, or `testing`).
                                    If `None`, all data will be used. Defaults to `None`.
            shuffle (bool, optional): Whether to shuffle the data. This is only applicable for the training set.
                                    Defaults to `False`.
            confidence_threshold (float, optional): The threshold for filtering out low-confidence entries in the
                                                    `GNoMEDatabase`. Defaults to `0.5`.

        Raises:
            ValueError: If shuffling is set to `True` for the validation or testing subset.

        Returns:
            (tuple): A tuple containing:
                - dataloader_db (DataLoader): The PyTorch DataLoader object containing the processed data.
                - mean_neighbors (float): The mean number of neighbors (calculated using the featurized data).
        """
        if subset != "training" and shuffle is True:
            logger.error("Shuffling is not supported for the validation and testing set!")
            raise ValueError("Shuffling is not supported for the validation and testing set!")

        if not isinstance(databases, list):
            databases = [databases]

        if isinstance(databases[0], GNoMEDatabase):
            df = databases[0].get_database("processed")
            df = df[df["classifier_mean"] > confidence_threshold]
        else:
            db_list = []
            for db in databases:
                db_list.append(db.load_regressor_data(subset))

            if len(db_list) > 1:
                df = pd.concat(db_list, ignore_index=True)
            else:
                df = db_list[0]
                df.reset_index(drop=True, inplace=True)

        featurized_db = self.featurize_db(df)

        data_values = featurized_db["data"].values
        dataloader_db = DataLoader(
            data_values,
            batch_size=self.batch_size,
            num_workers=0,
            pin_memory=True,
            shuffle=shuffle,
        )

        n_neighbors = get_neighbors(featurized_db)

        return dataloader_db, n_neighbors.mean()

    def get_optimizer_scheduler_loss(
        self,
        optimizer_class=torch.optim.AdamW,
        scheduler_class=torch.optim.lr_scheduler.ExponentialLR,
        scheduler_settings: dict[str, Any] = dict(gamma=0.96),
        loss_function=torch.nn.L1Loss,
    ):
        """
        Configure the optimizer, learning rate scheduler, and loss function for training.

        This method sets up the components required for model training, including the optimizer,
        learning rate scheduler, and loss function. The optimizer and scheduler are configured
        based on the provided class arguments, while the loss function is selected based on
        the callable function provided.

        Args:
            optimizer_class (torch.optim.Optimizer, optional): The optimizer class to use.
                                                            Defaults to `torch.optim.AdamW`.
            scheduler_class (torch.optim.lr_scheduler._LRScheduler, optional): The learning rate
                                                                scheduler class to use.
                                                                Defaults to `torch.optim.lr_scheduler.ExponentialLR`.
            scheduler_settings (dict, optional): A dictionary of settings for the learning rate scheduler.
                                                For example, `gamma` can be set to control the decay rate.
                                                Defaults to `{"gamma": 0.96}`.
            loss_function (callable, optional): The loss function to use for training.
                                                Defaults to `torch.nn.L1Loss`.

        Raises:
            ValueError: If `optimizer_class` is not a subclass of `torch.optim.Optimizer`.
            ValueError: If `scheduler_class` is not a subclass of `torch.optim.lr_scheduler._LRScheduler`.
            ValueError: If `loss_function` is not callable.
        """
        if not issubclass(optimizer_class, torch.optim.Optimizer):
            raise ValueError("optimizer_class must be a subclass of torch.optim.Optimizer.")
        if not issubclass(scheduler_class, torch.optim.lr_scheduler.LRScheduler):
            raise ValueError("scheduler_class must be a subclass of torch.optim.lr_scheduler.")
        if not callable(loss_function):
            raise ValueError("loss_function must be callable.")

        self.optimizer = optimizer_class
        self.scheduler = scheduler_class
        self.loss_function = loss_function
        self._scheduler_settings = scheduler_settings
        logger.info(f"Optimizer configured: {optimizer_class.__name__}")
        logger.info(f"Learning rate scheduler configured: {scheduler_class.__name__}")
        logger.info(f"Loss function configured: {loss_function.__name__}")

    def _build_regressor(self, num_neighbors: float):
        """
        Build and initialize the regressor models using specified settings.

        This method constructs a regressor model for each committer and stores the models
        in the `self.models` dictionary. Each model is initialized with unique settings,
        and the model configuration is saved as a JSON file for reproducibility. The
        models are built using a `PeriodicNetwork` with predefined settings, including atom
        embeddings, number of layers, and other hyperparameters.

        Args:
            num_neighbors (float): Scaling factor based on the typical number of neighbors
                                for the convolution operation. It influences the model's
                                sensitivity to local atomic environments.
        """
        out_dim = 1  # Predict a scalar output
        em_dim = 64

        models = {}

        for i in tqdm(range(self.n_committers), desc="models"):
            model_settings_path = self.models_dir / (self._model_spec + f".rep{i}.json")
            model_settings = dict(
                in_dim=118,  # dimension of one-hot encoding of atom type
                em_dim=em_dim,  # dimension of atom-type embedding
                irreps_in=str(em_dim)
                + "x0e",  # em_dim scalars (L=0 and even parity) on each atom to represent atom type
                irreps_out=str(out_dim) + "x0e",  # out_dim scalars (L=0 and even parity) to output
                irreps_node_attr=str(em_dim)
                + "x0e",  # em_dim scalars (L=0 and even parity) on each atom to represent atom type
                layers=self.conv_layers,  # number of nonlinearities (number of convolutions = layers + 1)
                mul=32,  # multiplicity of irreducible representations
                lmax=self.l_max,  # maximum order of spherical harmonics
                max_radius=self.r_max,  # cutoff radius for convolution
                num_neighbors=num_neighbors,  # scaling factor based on the typical number of neighbors
                reduce_output=True,  # whether or not to aggregate features of all atoms at the end
            )
            with open(model_settings_path, "w") as fp:
                json.dump(model_settings, fp)
            models[f"model_{i}"] = PeriodicNetwork(**model_settings)
        self.models: dict[str, torch.nn.Module] = models

    def compile_(
        self,
        num_neighbors: float,
        lr: float = DEFAULT_OPTIM_SETTINGS["lr"],
        wd: float = DEFAULT_OPTIM_SETTINGS["wd"],
        optimizer_class=torch.optim.AdamW,
        scheduler_class=torch.optim.lr_scheduler.ExponentialLR,
        scheduler_settings: dict[str, Any] = dict(gamma=0.99),
        loss_function=torch.nn.L1Loss,
    ):
        """
        Compile and configure the model for training, setting up necessary components
        such as the optimizer, learning rate scheduler, and loss function, and then
        builds the regressors.

        This method performs the following steps:
        1. Loads the optimizer settings (learning rate and weight decay).
        2. Configures the optimizer, learning rate scheduler, and loss function.
        3. Builds the regressor models based on the provided number of neighbors.

        Args:
            num_neighbors (float): The scaling factor based on the typical number of neighbors.
            lr (float, optional): The learning rate for the optimizer.
                                    Defaults to `DEFAULT_OPTIM_SETTINGS["lr"]`.
            wd (float, optional): The weight decay for the optimizer.
                                    Defaults to `DEFAULT_OPTIM_SETTINGS["wd"]`.
            optimizer_class (torch.optim.Optimizer, optional): The optimizer class to use.
                                                                Defaults to `torch.optim.AdamW`.
            scheduler_class (torch.optim.lr_scheduler._LRScheduler, optional): The learning rate scheduler class.
                                                                Defaults to `torch.optim.lr_scheduler.ExponentialLR`.
            scheduler_settings (dict, optional): The settings for the learning rate scheduler.
                                                    Defaults to `dict(gamma=0.99)`.
            loss_function (torch.nn.Module, optional): The loss function to use for training.
                                                        Defaults to `torch.nn.L1Loss`.

        Raises:
            ValueError: If `num_neighbors` is not a positive number.
        """
        if num_neighbors <= 0:
            raise ValueError("'num_neighbors' must be a positive number.")

        logger.info("[STEP 1] Loading settings")
        self.set_optimizer_settings(lr, wd)

        logger.info("[STEP 2] Configuring optimizer, lr_scheduler, and loss function for training")
        self.get_optimizer_scheduler_loss(
            optimizer_class,
            scheduler_class,
            scheduler_settings,
            loss_function,
        )

        logger.info("[STEP 3] Building regressors")
        self._build_regressor(num_neighbors)

    def fit(
        self,
        dataloader_train: DataLoader,
        dataloader_valid: DataLoader,
        n_epochs: int = DEFAULT_TRAINING_SETTINGS["n_epochs"],
        parallelize: bool = False,
    ):
        """
        Train the regressor models using the specified training and validation datasets.

        This method supports both sequential and parallelized training:

        - If `parallelize` is `True` and multiple GPUs are available, training is executed asynchronously.
        - If only one GPU is available, a warning is issued, and sequential training is used.
        - If no GPUs are available but the model is set to use CUDA, an error is raised.
        - Otherwise, models are trained sequentially.

        Args:
            dataloader_train (DataLoader): PyTorch DataLoader containing the training dataset.
            dataloader_valid (DataLoader): PyTorch DataLoader containing the validation dataset.
            n_epochs (int, optional): Number of training epochs. Defaults to `DEFAULT_TRAINING_SETTINGS["n_epochs"]`.
            parallelize (bool, optional): Whether to parallelize training across multiple GPUs. Defaults to `False`.

        Raises:
            RuntimeError: If CUDA is selected but no GPU is available.
        """
        self.set_training_settings(n_epochs)

        if parallelize and torch.cuda.is_available() and torch.cuda.device_count() > 1:
            self._run_async(dataloader_train, dataloader_valid)
        else:
            if parallelize and torch.cuda.is_available():
                logger.warning("Only one GPU available, the flag 'parallelize' is ignored.")
            elif not torch.cuda.is_available() and "cuda" in self.device:
                logger.error(
                    f"You are attempting to run the training on {self.device} but only CPUs are available, check your drivers' installation or model settings!"
                )
                raise RuntimeError(
                    f"You are attempting to run the training on {self.device} but only CPUs are available, check your drivers' installation or model settings!"
                )
            for i in range(self.n_committers):
                _train_model(
                    i,
                    self.device,
                    self._model_spec,
                    self.models[f"model_{i}"],
                    self.models_dir,
                    self.optimizer,
                    self.learning_rate,
                    self.weight_decay,
                    self.scheduler,
                    self._scheduler_settings,
                    self.loss_function,
                    self.n_epochs,
                    dataloader_train,
                    dataloader_valid,
                )

    # Run training asynchronously
    async def _multi(self, dataloader_train: DataLoader, dataloader_valid: DataLoader):
        """
        Run training asynchronously for multiple models.

        This method launches multiple training processes asynchronously,
        distributing the models across available GPU devices.

        Args:
            dataloader_train (DataLoader): Training dataset wrapped in a PyTorch DataLoader.
            dataloader_valid (DataLoader): Validation dataset wrapped in a PyTorch DataLoader.

        Raises:
            RuntimeError: If CUDA devices are not available but GPU training is attempted.

        Note:
            - Each model is assigned to a CUDA device in a round-robin fashion.
            - Uses asyncio to manage multiple concurrent training executions.
        """
        logger.info("[ASYNC] Launching training asynchronously...")
        loop = asyncio.get_event_loop()
        tasks = []

        for i in range(self.n_committers):
            logger.info(
                f"[ASYNC] Iteration {i+1}/{self.n_committers}: "
                f"Launching training on cuda:{i % torch.cuda.device_count()}..."
            )
            task = loop.run_in_executor(
                None,
                _train_model,
                i,
                f"cuda:{i % torch.cuda.device_count()}",
                self._model_spec,
                self.models[f"model_{i}"],
                self.models_dir,
                self.optimizer,
                self.learning_rate,
                self.weight_decay,
                self.scheduler,
                self._scheduler_settings,
                self.loss_function,
                self.n_epochs,
                dataloader_train,
                dataloader_valid,
            )
            tasks.append(task)
        await asyncio.gather(*tasks)
        logger.info("[ASYNC] All training processes completed!")

    def _run_async(self, dataloader_train: DataLoader, dataloader_valid: DataLoader):
        """
        Execute the asynchronous training routine based on the execution environment.

        This method ensures compatibility between Jupyter notebooks and standalone scripts
        when running the `_multi` training process. It checks for an active event loop:

        - If running inside a Jupyter notebook, it schedules `_multi` as a background task.
        - If running in a script, it uses `asyncio.run()` to properly start the event loop.

        Args:
            dataloader_train: The training dataset wrapped in a PyTorch DataLoader.
            dataloader_valid: The validation dataset wrapped in a PyTorch DataLoader.

        Returns:
            If running in a Jupyter notebook, returns an `asyncio.Task` object.
            If running as a script, the method executes `_multi` synchronously and returns `None`.
        """
        try:
            if asyncio.get_running_loop():  # Running inside Jupyter
                return asyncio.create_task(
                    self._multi(dataloader_train, dataloader_valid)
                )  # Run as a background task
        except RuntimeError:  # Running as a script
            logger.info("[SCRIPT] Running in a script. Using asyncio.run().")
            asyncio.run(self._multi(dataloader_train, dataloader_valid))

    def plot_history(self):
        """
        Plot the training and validation loss history for each trained model.

        This method iterates through the models saved in the model directory,
        loads their training history, and generates a plot comparing training and
        validation loss over epochs. The plot is saved as both PNG and PDF files
        in the figures directory.

        The plot will include:
            - X-axis: Epochs (steps)
            - Y-axis: Loss values
            - Two lines: Training loss and Validation loss

        Saves the generated plots as:
            - model_name_training.png
            - model_name_training.pdf

        Uses Matplotlib to generate the plots and saves them in the configured figures directory.

        Raises:
            FileNotFoundError: If no models are found in the specified model directory.
            KeyError: If the model history does not contain expected keys like "history".
        """
        models_found = False
        for f in os.listdir(self.models_dir):
            if f.endswith(".torch"):
                model_history_path = self.models_dir / f
                try:
                    # Load history
                    history = torch.load(
                        model_history_path, map_location=self.device, weights_only=True
                    )["history"]
                except KeyError:
                    logger.error(f"KeyError: 'history' not found in {model_history_path}")
                    raise KeyError(
                        f"Model history for {f} does not contain expected 'history' key."
                    )

                # If history is loaded, set flag to True
                models_found = True

                steps = [d["step"] + 1 for d in history]
                loss_train = [d["train"]["loss"] for d in history]
                loss_valid = [d["valid"]["loss"] for d in history]

                fig, ax = plt.subplots(figsize=(7, 3), dpi=150)
                ax.plot(steps, loss_train, "o-", label="Training")
                ax.plot(steps, loss_valid, "o-", label="Validation")
                ax.set_xlabel("epochs")
                ax.set_ylabel("loss")
                ax.legend(frameon=False)
                fig.savefig(
                    self.figures_dir / f.replace(".torch", "_training.png"),
                    dpi=330,
                    bbox_inches="tight",
                )
                fig.savefig(
                    self.figures_dir / f.replace(".torch", "_training.pdf"),
                    dpi=330,
                    bbox_inches="tight",
                )
                plt.show()

        # Raise error if no models are found
        if not models_found:
            logger.error("No models found in the specified model directory.")
            raise FileNotFoundError("No model files found in the models directory.")

    def evaluate(self, dataloader: DataLoader, return_df: bool = False):
        """
        Evaluate the performance of the regression model(s) on the provided dataset.

        This method runs inference on the given dataset and calculates the loss
        (L1 loss) for each model in the `self.models` list. It returns either a
        detailed DataFrame with predictions and losses or a dictionary of predictions
        for each model, depending on the `return_df` flag.

        Args:
            dataloader (DataLoader): The DataLoader object containing the dataset
                to be evaluated.
            return_df (bool): Whether to return the results as a DataFrame with
                predictions and losses (`True`), or a dictionary with per-model
                results (`False`). Default is `False`.

        Returns:
            (pd.DataFrame): If `return_df=True`, returns a pandas DataFrame where each column
                            corresponds to predictions and loss for each model.
                            The columns include:
                                - `true_value`: Ground truth values.
                                - `model_i_prediction`: Predictions from model `i`.
                                - `model_i_loss`: L1 loss for model `i`.

            (dict[str, pd.DataFrame]): If `return_df=False`, returns a dictionary where each key is
                    a model identifier (e.g., `model_0`, `model_1`, ...) and the value is a
                    DataFrame containing the following columns:
                        - `true_value`: Ground truth values.
                        - `prediction`: Predictions from the model.
                        - `loss`: L1 loss computed for each sample.
        """
        if return_df:
            prediction_nn = pd.DataFrame()
            prediction_nn["true_value"] = [
                item[0] for batch in dataloader for item in batch["target"].tolist()
            ]

            for i in tqdm(range(self.n_committers), desc="models"):
                prediction_nn[f"model_{i}_prediction"] = np.empty(
                    (len(dataloader.dataset), 1)
                ).tolist()
                prediction_nn[f"model_{i}_loss"] = 0.0

                self.models[f"model_{i}"].to(self.device)
                self.models[f"model_{i}"].eval()
                with torch.no_grad():
                    i0 = 0
                    for j, d in tqdm(enumerate(dataloader), total=len(dataloader)):
                        d.to(self.device)
                        output = self.models[f"model_{i}"](d)
                        loss = (
                            F.l1_loss(output, d.target, reduction="none")
                            .mean(dim=-1)
                            .cpu()
                            .numpy()
                        )
                        prediction_nn.loc[i0 : i0 + len(d.target) - 1, f"model_{i}_prediction"] = [
                            k for k in output.cpu().numpy()
                        ]
                        prediction_nn.loc[i0 : i0 + len(d.target) - 1, f"model_{i}_loss"] = loss
                        i0 += len(d.target)

        else:
            prediction_nn = {}
            for i in tqdm(range(self.n_committers), desc="models"):
                prediction_nn[f"model_{i}"] = pd.DataFrame()
                prediction_nn[f"model_{i}"]["loss"] = 0.0
                prediction_nn[f"model_{i}"]["true_value"] = [
                    item[0] for batch in dataloader for item in batch["target"].tolist()
                ]
                prediction_nn[f"model_{i}"]["prediction"] = np.empty(
                    (len(dataloader.dataset), 1)
                ).tolist()

                self.models[f"model_{i}"].to(self.device)
                self.models[f"model_{i}"].eval()
                with torch.no_grad():
                    i0 = 0
                    for j, d in tqdm(enumerate(dataloader), total=len(dataloader)):
                        d.to(self.device)
                        output = self.models[f"model_{i}"](d)
                        loss = (
                            F.l1_loss(output, d.target, reduction="none")
                            .mean(dim=-1)
                            .cpu()
                            .numpy()
                        )
                        prediction_nn[f"model_{i}"].loc[
                            i0 : i0 + len(d.target) - 1, "prediction"
                        ] = [k for k in output.cpu().numpy()]
                        prediction_nn[f"model_{i}"].loc[i0 : i0 + len(d.target) - 1, "loss"] = loss
                        i0 += len(d.target)

        return prediction_nn

    def plot_parity(
        self, predictions_dict: dict[str, pd.DataFrame], include_ensemble: bool = True
    ):
        """
        Plot a parity plot for model predictions and their comparison with true values.

        This method generates a scatter plot where the x-axis represents the true values,
        and the y-axis represents the predicted values from one or more models. It also
        includes a reference line (1:1 line) and error histograms as insets to visualize
        the prediction error distribution. Additionally, it calculates and annotates the R²
        value for each model's predictions and optionally for the ensemble average of all models.

        Args:
            predictions_dict (dict): A dictionary where keys are model names (e.g., 'model_1', 'model_2')
                and values are pandas DataFrames containing the `true_value` and `prediction` columns.
            include_ensemble (bool): If `True`, an ensemble prediction (mean of all model predictions)
                is included in the plot. Default is `True`.
        """
        all_predictions = []
        fig, ax = plt.subplots(figsize=(5, 3), dpi=300)

        colors = plt.cm.tab10.colors  # Get distinct colors for different models

        for i, (model, df) in enumerate(predictions_dict.items()):
            y_true = df["true_value"]
            y_predictions = df["prediction"]

            if include_ensemble:
                all_predictions.append(y_predictions)

            error = np.abs(y_predictions - y_true)
            r2 = r2_score(y_true, y_predictions)

            # Scatter plot with unique color
            ax.scatter(
                y_true,
                y_predictions,
                s=6,
                alpha=0.6,
                color=colors[i % len(colors)],
                label=model,
                zorder=1,
            )

            # Reference line (1:1 line)
            ax.axline(
                (np.mean(y_true), np.mean(y_true)), slope=1, lw=0.85, ls="--", color="k", zorder=2
            )

            # Add inset histogram
            if i == 0:  # Create inset only once
                axin = ax.inset_axes([0.65, 0.17, 0.3, 0.3])
            axin.hist(
                error, bins=int(np.sqrt(len(error))), alpha=0.6, color=colors[i % len(colors)]
            )
            axin.hist(error, bins=int(np.sqrt(len(error))), histtype="step", lw=1, color="black")

            # Annotate R² values dynamically
            ax.annotate(
                f"$R^2={r2:1.2f}$",
                xy=(0.05, 0.96 - (i * 0.07)),  # Adjust position dynamically
                xycoords="axes fraction",
                va="top",
                ha="left",
                fontsize=8,
                color=colors[i % len(colors)],
            )

        # Add ensemble prediction plot (if required)
        if include_ensemble:
            ensemble_predictions = np.mean(all_predictions, axis=0)
            ensemble_r2 = r2_score(y_true, ensemble_predictions)

            # Scatter plot for ensemble prediction
            ax.scatter(
                y_true,
                ensemble_predictions,
                s=6,
                alpha=0.6,
                color=colors[len(predictions_dict) % len(colors)],
                label="Ensemble",
                zorder=3,
            )

            # Annotate R² value for ensemble
            ax.annotate(
                f"Ensemble $R^2={ensemble_r2:1.2f}$",
                xy=(0.05, 0.96 - (len(predictions_dict) * 0.07)),  # Adjusted position for ensemble
                xycoords="axes fraction",
                va="top",
                ha="left",
                fontsize=8,
                color=colors[len(predictions_dict) % len(colors)],
            )

        # Set labels
        ax.set_xlabel("True value")
        ax.set_ylabel("Predicted value")

        # Set axis limits to ensure a square parity plot
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        min_ = min(min(xlim), min(ylim))
        max_ = max(max(xlim), max(ylim))
        ax.set_xlim(min_, max_)
        ax.set_ylim(min_, max_)

        # Move legend outside the plot
        ax.legend(loc="upper left", bbox_to_anchor=(1.05, 1), fontsize=8, borderaxespad=0.0)

        # Add title
        fig.suptitle("Parity Plot", fontsize=10)

        # Save figures
        fig.savefig(
            self.figures_dir / (self._model_spec + "_parity.png"), dpi=330, bbox_inches="tight"
        )
        fig.savefig(
            self.figures_dir / (self._model_spec + "_parity.pdf"), dpi=330, bbox_inches="tight"
        )

        # Show plot
        plt.show()

    def _evaluate_unknown(
        self,
        dataloader: DataLoader,
    ) -> pd.DataFrame:
        """
        Predicts the target property for candidate specialized materials using regressor models.

        Args:
            dataloader (DataLoader): A PyTorch DataLoader containing the dataset to be evaluated.

        Returns:
            (pd.DataFrame): A DataFrame containing the predictions and uncertainties for the dataset,
                            with one column for each regressor model and additional columns for
                            the mean (`regressor_mean`) and standard deviation (`regressor_std`)
                            of the predictions across all models.
        """
        prediction_nn = pd.DataFrame()

        for i in tqdm(range(self.n_committers), desc="models"):
            prediction_nn[f"regressor_{i}"] = np.empty((len(dataloader.dataset), 1)).tolist()

            self.models[f"model_{i}"].to(self.device)
            self.models[f"model_{i}"].eval()
            with torch.no_grad():
                i0 = 0
                for j, d in tqdm(enumerate(dataloader), total=len(dataloader)):
                    d.to(self.device)
                    output = self.models[f"model_{i}"](d)
                    prediction_nn.loc[i0 : i0 + len(d.symbol) - 1, f"regressor_{i}"] = [
                        float(k) for k in output.cpu().numpy()
                    ]
                    i0 += len(d.symbol)

        prediction_nn["regressor_mean"] = prediction_nn[
            [f"regressor_{i}" for i in range(self.n_committers)]
        ].mean(axis=1)
        prediction_nn["regressor_std"] = prediction_nn[
            [f"regressor_{i}" for i in range(self.n_committers)]
        ].std(axis=1)

        return prediction_nn

    def predict(self, db: GNoMEDatabase, confidence_threshold: float = 0.5, save_final=True):
        """
        Predicts the target property for candidate specialized materials using regressor models,
        after filtering materials based on classifier committee confidence.

        Args:
            db (GNoMEDatabase): The database containing the materials and their properties.
            confidence_threshold (float, optional): The minimum classifier committee confidence
                                                    required to keep a material for prediction.
                                                    Defaults to `0.5`.
            save_final (bool, optional): Whether to save the final database with predictions.
                                        Defaults to `True`.

        Returns:
            (pd.DataFrame): A DataFrame containing the predictions, along with the true values and
                        classifier committee confidence scores for the screened materials.

        Notes:
            - The method filters the materials based on the classifier confidence, then uses the
            regressor models to predict the target property for the remaining materials.
            - If `save_final` is set to True, the predictions are saved to the database in the
            `final` stage.
        """
        logger.info(
            f"Discarding materials with classifier committee confidence threshold < {confidence_threshold}."
        )
        logger.info("Featurizing and loading database as `tg.loader.DataLoader`.")
        dataloader_db, _ = self.create_dataloader(db, confidence_threshold=confidence_threshold)
        logger.info("Predicting the target property for candidate specialized materials.")
        predictions = self._evaluate_unknown(dataloader_db)
        df = db.get_database("processed")
        screened = df[df["classifier_mean"] > confidence_threshold]
        predictions = pd.concat([screened.reset_index(), predictions], axis=1)

        if save_final:
            logger.info("Saving the final database.")
            db.databases["final"] = predictions.copy()
            db.save_database("final")

        return predictions

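The following is a minimal end-to-end sketch of how the methods above fit together. It assumes `db` is an instance of a `BaseDatabase` subclass whose regressor splits have already been prepared, and uses "band_gap" as an illustrative target property; neither is defined by this module, and the accepted setting keys are those in `DEFAULT_E3NN_SETTINGS`.

Python
from energy_gnome.models import E3NNRegressor

# `db` and "band_gap" are illustrative stand-ins, not values defined by this module.
regressor = E3NNRegressor(model_name="e3nn_demo", target_property="band_gap")

# Apply (and persist to YAML) the default settings from DEFAULT_E3NN_SETTINGS.
regressor.set_model_settings()

# Featurize the splits and wrap them in DataLoaders; the second return value is
# the mean number of neighbors, used to scale the convolutions.
dl_train, n_neigh = regressor.create_dataloader(db, subset="training", shuffle=True)
dl_valid, _ = regressor.create_dataloader(db, subset="validation")

# Build the committee of PeriodicNetwork regressors and configure the optimizer,
# scheduler, and loss function.
regressor.compile_(num_neighbors=n_neigh)

# Train, inspect the loss curves, and check parity on the validation split.
regressor.fit(dl_train, dl_valid, n_epochs=50)
regressor.plot_history()
predictions = regressor.evaluate(dl_valid)
regressor.plot_parity(predictions)
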
__init__(model_name, target_property, models_dir=MODELS_DIR, figures_dir=FIGURES_DIR)

Initialize the E3NNRegressor with directories for storing models and figures.

This class extends BaseRegressor to implement an equivariant neural network (E3NN) for regression tasks. It sets up the necessary directory structure and configurations for training models.

Parameters:

    model_name (str): Name of the model, used to create subdirectories. [required]
    target_property (str): The target property the model is trained to predict. [required]
    models_dir (Path | str, optional): Directory for storing trained model weights. Defaults to MODELS_DIR from config.
    figures_dir (Path | str, optional): Directory for saving figures and visualizations. Defaults to FIGURES_DIR from config.

Attributes:

    _model_spec (str): Specification string used for model identification.
    l_max (int): Maximum order of the spherical harmonics used in the E3NN model (default is 2).
    r_max (int): Cutoff radius used in the E3NN model (default is 4).
    conv_layers (int): Number of nonlinearities (number of convolutions = layers + 1, default is 2).

Source code in energy_gnome/models/e3nn/regressor.py
Python
def __init__(
    self,
    model_name: str,
    target_property: str,
    models_dir: Path | str = MODELS_DIR,
    figures_dir: Path | str = FIGURES_DIR,
):
    """
    Initialize the E3NNRegressor with directories for storing models and figures.

    This class extends `BaseRegressor` to implement an equivariant neural network (E3NN)
    for regression tasks. It sets up the necessary directory structure and configurations
    for training models.

    Args:
        model_name (str): Name of the model, used to create subdirectories.
        target_property (str): The target property the model is trained to predict.
        models_dir (Path | str, optional): Directory for storing trained model weights.
                                        Defaults to `MODELS_DIR` from config.
        figures_dir (Path | str, optional): Directory for saving figures and visualizations.
                                            Defaults to `FIGURES_DIR` from config.

    Attributes:
        _model_spec (str): Specification string used for model identification.
        l_max (int): Maximum order of the spherical harmonics used in the E3NN model (default is 2).
        r_max (int): Cutoff radius used in the E3NN model (default is 4).
        conv_layers (int): Number of nonlinearities (number of convolutions = layers + 1, default is 2).
    """

    self._model_spec = "model.e3nn_regressor." + target_property
    super().__init__(
        model_name=model_name,
        target_property=target_property,
        models_dir=models_dir,
        figures_dir=figures_dir,
    )
    self.l_max: int = 2
    self.r_max: int = 4
    self.conv_layers: int = 2

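A short instantiation sketch; the model name, target property, and directories below are illustrative and not defined by this module:

Python
from pathlib import Path

from energy_gnome.models import E3NNRegressor

reg = E3NNRegressor(
    model_name="perovskite_bandgap",  # illustrative name
    target_property="band_gap",       # illustrative target property
    models_dir=Path("models"),        # overrides MODELS_DIR
    figures_dir=Path("figures"),      # overrides FIGURES_DIR
)

# The E3NN hyperparameters set in __init__ can be adjusted before compiling.
reg.l_max = 3        # higher-order spherical harmonics
reg.r_max = 5        # larger cutoff radius
reg.conv_layers = 3  # one extra nonlinearity (convolutions = layers + 1)
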
compile_(num_neighbors, lr=DEFAULT_OPTIM_SETTINGS['lr'], wd=DEFAULT_OPTIM_SETTINGS['wd'], optimizer_class=torch.optim.AdamW, scheduler_class=torch.optim.lr_scheduler.ExponentialLR, scheduler_settings=dict(gamma=0.99), loss_function=torch.nn.L1Loss)

Compile and configure the model for training, setting up necessary components such as the optimizer, learning rate scheduler, and loss function, and then builds the regressors.

This method performs the following steps: 1. Loads the optimizer settings (learning rate and weight decay). 2. Configures the optimizer, learning rate scheduler, and loss function. 3. Builds the regressor models based on the provided number of neighbors.

Parameters:

    num_neighbors (float): The scaling factor based on the typical number of neighbors. [required]
    lr (float, optional): The learning rate for the optimizer. Defaults to DEFAULT_OPTIM_SETTINGS["lr"].
    wd (float, optional): The weight decay for the optimizer. Defaults to DEFAULT_OPTIM_SETTINGS["wd"].
    optimizer_class (torch.optim.Optimizer, optional): The optimizer class to use. Defaults to torch.optim.AdamW.
    scheduler_class (torch.optim.lr_scheduler._LRScheduler, optional): The learning rate scheduler class. Defaults to torch.optim.lr_scheduler.ExponentialLR.
    scheduler_settings (dict, optional): The settings for the learning rate scheduler. Defaults to dict(gamma=0.99).
    loss_function (torch.nn.Module, optional): The loss function to use for training. Defaults to torch.nn.L1Loss.

Raises:

    ValueError: If num_neighbors is not a positive number.

Source code in energy_gnome/models/e3nn/regressor.py
Python
def compile_(
    self,
    num_neighbors: float,
    lr: float = DEFAULT_OPTIM_SETTINGS["lr"],
    wd: float = DEFAULT_OPTIM_SETTINGS["wd"],
    optimizer_class=torch.optim.AdamW,
    scheduler_class=torch.optim.lr_scheduler.ExponentialLR,
    scheduler_settings: dict[str, Any] = dict(gamma=0.99),
    loss_function=torch.nn.L1Loss,
):
    """
    Compile and configure the model for training, setting up necessary components
    such as the optimizer, learning rate scheduler, and loss function, and then
    builds the regressors.

    This method performs the following steps:
    1. Loads the optimizer settings (learning rate and weight decay).
    2. Configures the optimizer, learning rate scheduler, and loss function.
    3. Builds the regressor models based on the provided number of neighbors.

    Args:
        num_neighbors (float): The scaling factor based on the typical number of neighbors.
        lr (float, optional): The learning rate for the optimizer.
                                Defaults to `DEFAULT_OPTIM_SETTINGS["lr"]`.
        wd (float, optional): The weight decay for the optimizer.
                                Defaults to `DEFAULT_OPTIM_SETTINGS["wd"]`.
        optimizer_class (torch.optim.Optimizer, optional): The optimizer class to use.
                                                            Defaults to `torch.optim.AdamW`.
        scheduler_class (torch.optim.lr_scheduler._LRScheduler, optional): The learning rate scheduler class.
                                                            Defaults to `torch.optim.lr_scheduler.ExponentialLR`.
        scheduler_settings (dict, optional): The settings for the learning rate scheduler.
                                                Defaults to `dict(gamma=0.99)`.
        loss_function (torch.nn.Module, optional): The loss function to use for training.
                                                    Defaults to `torch.nn.L1Loss`.

    Raises:
        ValueError: If `num_neighbors` is not a positive number.
    """
    if num_neighbors <= 0:
        raise ValueError("'num_neighbors' must be a positive number.")

    logger.info("[STEP 1] Loading settings")
    self.set_optimizer_settings(lr, wd)

    logger.info("[STEP 2] Configuring optimizer, lr_scheduler, and loss function for training")
    self.get_optimizer_scheduler_loss(
        optimizer_class,
        scheduler_class,
        scheduler_settings,
        loss_function,
    )

    logger.info("[STEP 3] Building regressors")
    self._build_regressor(num_neighbors)
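
A minimal usage sketch, assuming `regressor` is an already-constructed E3NNRegressor and `perovskite_db` is a processed database instance (both names are placeholders): the mean neighbor count returned by create_dataloader (documented below) is the natural value to pass as num_neighbors.

Python
import torch

# Illustrative names: `regressor` and `perovskite_db` are assumed to exist.
train_loader, mean_neighbors = regressor.create_dataloader(
    perovskite_db, subset="training", shuffle=True
)

# Use the mean neighbor count from the training data as the scaling factor.
regressor.compile_(
    num_neighbors=mean_neighbors,
    lr=5e-3,
    wd=0.05,
    optimizer_class=torch.optim.AdamW,
    scheduler_class=torch.optim.lr_scheduler.ExponentialLR,
    scheduler_settings=dict(gamma=0.99),
    loss_function=torch.nn.L1Loss,
)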

create_dataloader(databases, subset=None, shuffle=False, confidence_threshold=0.5)

Format and return a PyTorch DataLoader for training, validation, or testing.

This method prepares the dataloaders for PyTorch training by featurizing the input datasets and handling multiple types of databases. It ensures that shuffling is only applied to the training subset, and it filters out low-confidence samples from GNoMEDatabase.

Parameters:

Name Type Description Default
databases BaseDatabase | list[BaseDatabase]

A single instance or a list of BaseDatabase objects containing processed data. If multiple databases are provided, they will be concatenated.

required
subset str

The specific data subset to use (training, validation, or testing). If None, all data will be used. Defaults to None.

None
shuffle bool

Whether to shuffle the data. This is only applicable for the training set. Defaults to False.

False
confidence_threshold float

The threshold for filtering out low-confidence entries in the GNoMEDatabase. Defaults to 0.5.

0.5

Raises:

Type Description
ValueError

If shuffling is set to True for the validation or testing subset.

Returns:

Type Description
tuple

A tuple containing:

  • dataloader_db (DataLoader): The PyTorch DataLoader object containing the processed data.
  • mean_neighbors (float): The mean number of neighbors (calculated using the featurized data).

Source code in energy_gnome/models/e3nn/regressor.py
Python
def create_dataloader(
    self,
    databases: BaseDatabase | list[BaseDatabase],
    subset: str | None = None,
    shuffle: bool = False,
    confidence_threshold: float = 0.5,
):
    """
    Format and return a PyTorch DataLoader for training, validation, or testing.

    This method prepares the dataloaders for PyTorch training by featurizing the input datasets
    and handling multiple types of databases. It ensures that shuffling is only applied to the
    training subset, and it filters out low-confidence samples from GNoMEDatabase.

    Args:
        databases (BaseDatabase | list[BaseDatabase]): A single instance or a list of `BaseDatabase` objects
                                                    containing processed data. If multiple databases are provided,
                                                    they will be concatenated.
        subset (str, optional): The specific data subset to use (`training`, `validation`, or `testing`).
                                If `None`, all data will be used. Defaults to `None`.
        shuffle (bool, optional): Whether to shuffle the data. This is only applicable for the training set.
                                Defaults to `False`.
        confidence_threshold (float, optional): The threshold for filtering out low-confidence entries in the
                                                `GNoMEDatabase`. Defaults to `0.5`.

    Raises:
        ValueError: If shuffling is set to `True` for the validation or testing subset.

    Returns:
        (tuple): A tuple containing:
            - dataloader_db (DataLoader): The PyTorch DataLoader object containing the processed data.
            - mean_neighbors (float): The mean number of neighbors (calculated using the featurized data).
    """
    if subset != "training" and shuffle is True:
        logger.error("Shuffling is not supported for the validation and testing set!")
        raise ValueError("Shuffling is not supported for the validation and testing set!")

    if not isinstance(databases, list):
        databases = [databases]

    if isinstance(databases[0], GNoMEDatabase):
        df = databases[0].get_database("processed")
        df = df[df["classifier_mean"] > confidence_threshold]
    else:
        db_list = []
        for db in databases:
            db_list.append(db.load_regressor_data(subset))

        if len(db_list) > 1:
            df = pd.concat(db_list, ignore_index=True)
        else:
            df = db_list[0]
            df.reset_index(drop=True, inplace=True)

    featurized_db = self.featurize_db(df)

    data_values = featurized_db["data"].values
    dataloader_db = DataLoader(
        data_values,
        batch_size=self.batch_size,
        num_workers=0,
        pin_memory=True,
        shuffle=shuffle,
    )

    n_neighbors = get_neighbors(featurized_db)

    return dataloader_db, n_neighbors.mean()
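
An illustrative sketch, assuming `regressor`, `cathode_db`, and `perovskite_db` already exist: multiple databases are concatenated into one loader, and shuffling is only requested for the training subset.

Python
# Illustrative names only; the databases are processed BaseDatabase subclasses.
train_loader, n_train_neighbors = regressor.create_dataloader(
    [cathode_db, perovskite_db], subset="training", shuffle=True
)
valid_loader, _ = regressor.create_dataloader(
    [cathode_db, perovskite_db], subset="validation", shuffle=False
)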

evaluate(dataloader, return_df=False)

Evaluate the performance of the regression model(s) on the provided dataset.

This method runs inference on the given dataset and calculates the loss (L1 loss) for each model in the self.models list. It returns either a detailed DataFrame with predictions and losses or a dictionary of predictions for each model, depending on the return_df flag.

Parameters:

Name Type Description Default
dataloader DataLoader

The DataLoader object containing the dataset to be evaluated.

required
return_df bool

Whether to return the results as a DataFrame with predictions and losses (True), or a dictionary with per-model results (False). Default is False.

False

Returns:

Type Description
DataFrame

If return_df=True, returns a pandas DataFrame where each column corresponds to predictions and loss for each model. The columns include:

  • true_value: Ground truth values.
  • model_i_prediction: Predictions from model i.
  • model_i_loss: L1 loss for model i.

dict[str, DataFrame]

If return_df=False, returns a dictionary where each key is a model identifier (e.g., model_0, model_1, ...) and the value is a DataFrame containing the following columns:

  • true_value: Ground truth values.
  • prediction: Predictions from the model.
  • loss: L1 loss computed for each sample.

Source code in energy_gnome/models/e3nn/regressor.py
Python
def evaluate(self, dataloader: DataLoader, return_df: bool = False):
    """
    Evaluate the performance of the regression model(s) on the provided dataset.

    This method runs inference on the given dataset and calculates the loss
    (L1 loss) for each model in the `self.models` list. It returns either a
    detailed DataFrame with predictions and losses or a dictionary of predictions
    for each model, depending on the `return_df` flag.

    Args:
        dataloader (DataLoader): The DataLoader object containing the dataset
            to be evaluated.
        return_df (bool): Whether to return the results as a DataFrame with
            predictions and losses (`True`), or a dictionary with per-model
            results (`False`). Default is `False`.

    Returns:
        (pd.DataFrame): If `return_df=True`, returns a pandas DataFrame where each column
                        corresponds to predictions and loss for each model.
                        The columns include:
                            - `true_value`: Ground truth values.
                            - `model_i_prediction`: Predictions from model `i`.
                            - `model_i_loss`: L1 loss for model `i`.

        (dict[str, pd.DataFrame]): If `return_df=False`, returns a dictionary where each key is
                a model identifier (e.g., `model_0`, `model_1`, ...) and the value is a
                DataFrame containing the following columns:
                    - `true_value`: Ground truth values.
                    - `prediction`: Predictions from the model.
                    - `loss`: L1 loss computed for each sample.
    """
    if return_df:
        prediction_nn = pd.DataFrame()
        prediction_nn["true_value"] = [
            item[0] for batch in dataloader for item in batch["target"].tolist()
        ]

        for i in tqdm(range(self.n_committers), desc="models"):
            prediction_nn[f"model_{i}_prediction"] = np.empty(
                (len(dataloader.dataset), 1)
            ).tolist()
            prediction_nn[f"model_{i}_loss"] = 0.0

            self.models[f"model_{i}"].to(self.device)
            self.models[f"model_{i}"].eval()
            with torch.no_grad():
                i0 = 0
                for j, d in tqdm(enumerate(dataloader), total=len(dataloader)):
                    d.to(self.device)
                    output = self.models[f"model_{i}"](d)
                    loss = (
                        F.l1_loss(output, d.target, reduction="none")
                        .mean(dim=-1)
                        .cpu()
                        .numpy()
                    )
                    prediction_nn.loc[i0 : i0 + len(d.target) - 1, f"model_{i}_prediction"] = [
                        k for k in output.cpu().numpy()
                    ]
                    prediction_nn.loc[i0 : i0 + len(d.target) - 1, f"model_{i}_loss"] = loss
                    i0 += len(d.target)

    else:
        prediction_nn = {}
        for i in tqdm(range(self.n_committers), desc="models"):
            prediction_nn[f"model_{i}"] = pd.DataFrame()
            prediction_nn[f"model_{i}"]["loss"] = 0.0
            prediction_nn[f"model_{i}"]["true_value"] = [
                item[0] for batch in dataloader for item in batch["target"].tolist()
            ]
            prediction_nn[f"model_{i}"]["prediction"] = np.empty(
                (len(dataloader.dataset), 1)
            ).tolist()

            self.models[f"model_{i}"].to(self.device)
            self.models[f"model_{i}"].eval()
            with torch.no_grad():
                i0 = 0
                for j, d in tqdm(enumerate(dataloader), total=len(dataloader)):
                    d.to(self.device)
                    output = self.models[f"model_{i}"](d)
                    loss = (
                        F.l1_loss(output, d.target, reduction="none")
                        .mean(dim=-1)
                        .cpu()
                        .numpy()
                    )
                    prediction_nn[f"model_{i}"].loc[
                        i0 : i0 + len(d.target) - 1, "prediction"
                    ] = [k for k in output.cpu().numpy()]
                    prediction_nn[f"model_{i}"].loc[i0 : i0 + len(d.target) - 1, "loss"] = loss
                    i0 += len(d.target)

    return prediction_nn
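
An illustrative sketch (names assumed): evaluate the committee on a held-out test loader, once as a wide DataFrame and once as a per-model dictionary.

Python
# Illustrative names: `regressor` and `cathode_db` are assumed to exist.
test_loader, _ = regressor.create_dataloader(cathode_db, subset="testing")

# Wide DataFrame: one prediction/loss column pair per committee member.
results_df = regressor.evaluate(test_loader, return_df=True)
print(results_df[["true_value", "model_0_prediction", "model_0_loss"]].head())

# Per-model dictionary, e.g. as input for plot_parity.
results_by_model = regressor.evaluate(test_loader, return_df=False)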

featurize_db(dataset)

Featurize the given dataset by processing the CIF file paths and extracting structural and chemical information.

This method reads the CIF files specified in the input dataset, extracts chemical information (such as formulae and species), and then generates featurized data suitable for model training. It also preserves the target property if it exists in the dataset.

Parameters:

Name Type Description Default
dataset DataFrame

Input dataset containing the CIF file paths and optionally the target property.

required

Returns:

Type Description
DataFrame

The featurized dataset, including chemical information and features for model training. The dataset includes columns for 'structure', 'species', 'formula', and 'data' (featurized data).

Source code in energy_gnome/models/e3nn/regressor.py
Python
def featurize_db(self, dataset: pd.DataFrame) -> pd.DataFrame:
    """
    Featurize the given dataset by processing the CIF file paths and extracting
    structural and chemical information.

    This method reads the CIF files specified in the input dataset, extracts chemical
    information (such as formulae and species), and then generates featurized data suitable
    for model training. It also preserves the target property if it exists in the dataset.

    Args:
        dataset (pd.DataFrame): Input dataset containing the CIF file paths and optionally
                                the target property.

    Returns:
        (pd.DataFrame): The featurized dataset, including chemical information and features
                    for model training. The dataset includes columns for 'structure', 'species',
                    'formula', and 'data' (featurized data).
    """
    if dataset.empty:
        return pd.DataFrame()

    # Initialize tqdm for apply functions
    tqdm.pandas()

    # Create a copy to avoid modifying the original dataset
    # database_nn = dataset.copy()
    database_nn = pd.DataFrame(
        {
            "structure": pd.Series(dtype="object"),
            "species": pd.Series(dtype="object"),
        }
    )

    # Read all structures in parallel with progress tracking
    database_nn["structure"] = dataset["cif_path"]
    database_nn["structure"] = database_nn["structure"].progress_apply(
        lambda path: read(Path(to_unix(path)))
    )

    # Extract formula and species info
    database_nn["formula"] = database_nn["structure"].progress_apply(
        lambda s: s.get_chemical_formula()
    )
    database_nn["species"] = database_nn["structure"].progress_apply(
        lambda s: list(set(s.get_chemical_symbols()))
    )

    # Preserve target property
    if self.target_property in dataset.columns:
        database_nn[self.target_property] = dataset[self.target_property]

    # Get encoding for featurization
    type_encoding, type_onehot, atomicmass_onehot = get_encoding()
    r_max = self.r_max  # cutoff radius

    # Apply featurization with progress bar
    database_nn["data"] = database_nn.progress_apply(
        lambda x: build_data(
            x,
            type_encoding,
            type_onehot,
            atomicmass_onehot,
            self.target_property if self.target_property in database_nn.columns else None,
            r_max,
            dtype=DEFAULT_DTYPE,
        ),
        axis=1,
    )

    return database_nn
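
An illustrative sketch (names assumed): featurize a processed DataFrame that holds CIF file paths and, optionally, the target property column.

Python
# Illustrative names: `regressor` and `cathode_db` are assumed to exist.
processed_df = cathode_db.load_regressor_data("training")
featurized = regressor.featurize_db(processed_df)

# 'structure', 'species', and 'formula' carry chemical info; 'data' holds the
# featurized objects consumed by the DataLoader.
print(featurized[["formula", "species"]].head())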

fit(dataloader_train, dataloader_valid, n_epochs=DEFAULT_TRAINING_SETTINGS['n_epochs'], parallelize=False)

Train the regressor models using the specified training and validation datasets.

This method supports both sequential and parallelized training:

  • If parallelize is True and multiple GPUs are available, training is executed asynchronously.
  • If only one GPU is available, a warning is issued, and sequential training is used.
  • If no GPUs are available but the model is set to use CUDA, an error is raised.
  • Otherwise, models are trained sequentially.

Parameters:

Name Type Description Default
dataloader_train DataLoader

PyTorch DataLoader containing the training dataset.

required
dataloader_valid DataLoader

PyTorch DataLoader containing the validation dataset.

required
n_epochs int

Number of training epochs. Defaults to DEFAULT_TRAINING_SETTINGS["n_epochs"].

DEFAULT_TRAINING_SETTINGS['n_epochs']
parallelize bool

Whether to parallelize training across multiple GPUs. Defaults to False.

False

Raises:

Type Description
RuntimeError

If CUDA is selected but no GPU is available.

Source code in energy_gnome/models/e3nn/regressor.py
Python
def fit(
    self,
    dataloader_train: DataLoader,
    dataloader_valid: DataLoader,
    n_epochs: int = DEFAULT_TRAINING_SETTINGS["n_epochs"],
    parallelize: bool = False,
):
    """
    Train the regressor models using the specified training and validation datasets.

    This method supports both sequential and parallelized training:

    - If `parallelize` is `True` and multiple GPUs are available, training is executed asynchronously.
    - If only one GPU is available, a warning is issued, and sequential training is used.
    - If no GPUs are available but the model is set to use CUDA, an error is raised.
    - Otherwise, models are trained sequentially.

    Args:
        dataloader_train (DataLoader): PyTorch DataLoader containing the training dataset.
        dataloader_valid (DataLoader): PyTorch DataLoader containing the validation dataset.
        n_epochs (int, optional): Number of training epochs. Defaults to `DEFAULT_TRAINING_SETTINGS["n_epochs"]`.
        parallelize (bool, optional): Whether to parallelize training across multiple GPUs. Defaults to `False`.

    Raises:
        RuntimeError: If CUDA is selected but no GPU is available.
    """
    self.set_training_settings(n_epochs)

    if parallelize and torch.cuda.is_available() and torch.cuda.device_count() > 1:
        self._run_async(dataloader_train, dataloader_valid)
    else:
        if parallelize and torch.cuda.is_available():
            logger.warning("Only one GPU available, the flag 'parallelize' is ignored.")
        elif not torch.cuda.is_available() and "cuda" in self.device:
            logger.error(
                f"You are attempting to run the training on {self.device} but only CPUs are available, check your drivers' installation or model settings!"
            )
            raise RuntimeError(
                f"You are attempting to run the training on {self.device} but only CPUs are available, check your drivers' installation or model settings!"
            )
        for i in range(self.n_committers):
            _train_model(
                i,
                self.device,
                self._model_spec,
                self.models[f"model_{i}"],
                self.models_dir,
                self.optimizer,
                self.learning_rate,
                self.weight_decay,
                self.scheduler,
                self._scheduler_settings,
                self.loss_function,
                self.n_epochs,
                dataloader_train,
                dataloader_valid,
            )
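
An illustrative sketch (names assumed): train the committee after compile_, enabling parallel training only when more than one GPU is visible.

Python
import torch

# `regressor`, `train_loader`, and `valid_loader` are assumed to exist.
regressor.fit(
    dataloader_train=train_loader,
    dataloader_valid=valid_loader,
    n_epochs=100,
    parallelize=torch.cuda.device_count() > 1,
)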

get_optimizer_scheduler_loss(optimizer_class=torch.optim.AdamW, scheduler_class=torch.optim.lr_scheduler.ExponentialLR, scheduler_settings=dict(gamma=0.96), loss_function=torch.nn.L1Loss)

Configure the optimizer, learning rate scheduler, and loss function for training.

This method sets up the components required for model training, including the optimizer, learning rate scheduler, and loss function. The optimizer and scheduler are configured based on the provided class arguments, while the loss function is selected based on the callable function provided.

Parameters:

Name Type Description Default
optimizer_class Optimizer

The optimizer class to use. Defaults to torch.optim.AdamW.

AdamW
scheduler_class _LRScheduler

The learning rate scheduler class to use. Defaults to torch.optim.lr_scheduler.ExponentialLR.

ExponentialLR
scheduler_settings dict

A dictionary of settings for the learning rate scheduler. For example, gamma can be set to control the decay rate. Defaults to {"gamma": 0.96}.

dict(gamma=0.96)
loss_function callable

The loss function to use for training. Defaults to torch.nn.L1Loss.

L1Loss

Raises:

Type Description
ValueError

If optimizer_class is not a subclass of torch.optim.Optimizer.

ValueError

If scheduler_class is not a subclass of torch.optim.lr_scheduler._LRScheduler.

ValueError

If loss_function is not callable.

Source code in energy_gnome/models/e3nn/regressor.py
Python
def get_optimizer_scheduler_loss(
    self,
    optimizer_class=torch.optim.AdamW,
    scheduler_class=torch.optim.lr_scheduler.ExponentialLR,
    scheduler_settings: dict[str, Any] = dict(gamma=0.96),
    loss_function=torch.nn.L1Loss,
):
    """
    Configure the optimizer, learning rate scheduler, and loss function for training.

    This method sets up the components required for model training, including the optimizer,
    learning rate scheduler, and loss function. The optimizer and scheduler are configured
    based on the provided class arguments, while the loss function is selected based on
    the callable function provided.

    Args:
        optimizer_class (torch.optim.Optimizer, optional): The optimizer class to use.
                                                        Defaults to `torch.optim.AdamW`.
        scheduler_class (torch.optim.lr_scheduler._LRScheduler, optional): The learning rate
                                                            scheduler class to use.
                                                            Defaults to `torch.optim.lr_scheduler.ExponentialLR`.
        scheduler_settings (dict, optional): A dictionary of settings for the learning rate scheduler.
                                            For example, `gamma` can be set to control the decay rate.
                                            Defaults to `{"gamma": 0.96}`.
        loss_function (callable, optional): The loss function to use for training.
                                            Defaults to `torch.nn.L1Loss`.

    Raises:
        ValueError: If `optimizer_class` is not a subclass of `torch.optim.Optimizer`.
        ValueError: If `scheduler_class` is not a subclass of `torch.optim.lr_scheduler._LRScheduler`.
        ValueError: If `loss_function` is not callable.
    """
    if not issubclass(optimizer_class, torch.optim.Optimizer):
        raise ValueError("optimizer_class must be a subclass of torch.optim.Optimizer.")
    if not issubclass(scheduler_class, torch.optim.lr_scheduler.LRScheduler):
        raise ValueError("scheduler_class must be a subclass of torch.optim.lr_scheduler.")
    if not callable(loss_function):
        raise ValueError("loss_function must be callable.")

    self.optimizer = optimizer_class
    self.scheduler = scheduler_class
    self.loss_function = loss_function
    self._scheduler_settings = scheduler_settings
    logger.info(f"Optimizer configured: {optimizer_class.__name__}")
    logger.info(f"Learning rate scheduler configured: {scheduler_class.__name__}")
    logger.info(f"Loss function configured: {loss_function.__name__}")

load_trained_models(state='state_best')

Load trained models from the model directory.

This method searches for trained models by:

  1. Loading model settings from .yaml files matching _model_spec.
  2. Initializing models based on corresponding .json configuration files.
  3. Loading the model weights from .torch files.
  4. Storing the loaded models in self.models.

Parameters:

Name Type Description Default
state str

The key used to extract model weights from the saved state dictionary (e.g., "state_best"). Defaults to "state_best".

'state_best'

Returns:

Type Description

list[str]: A list of .torch model filenames that were found in the directory.

Source code in energy_gnome/models/e3nn/regressor.py
Python
def load_trained_models(self, state: str = "state_best"):
    """
    Load trained models from the model directory.

    This method searches for trained models by:
    1. Loading model settings from `.yaml` files matching `_model_spec`.
    2. Initializing models based on corresponding `.json` configuration files.
    3. Loading the model weights from `.torch` files.
    4. Storing the loaded models in `self.models`.

    Args:
        state (str, optional): The key used to extract model weights from the saved
                            state dictionary (e.g., `"state_best"`). Defaults to `"state_best"`.

    Returns:
        list[str]: A list of `.torch` model filenames that were found in the directory.
    """
    for yaml_path in self.models_dir.glob(f"*{self._model_spec}.yaml"):
        self.set_model_settings(yaml_file=yaml_path)

    i = 0
    loaded_models = []
    for model_path in self.models_dir.glob("*.torch"):
        if self._model_spec in model_path.name:
            model_setting_path = model_path.with_suffix(".json")
            if model_setting_path.exists():
                try:
                    logger.info(f"Loading model with setting in {model_setting_path}")
                    logger.info(f"And weights in {model_path}")
                    model_setting = load_json(model_setting_path)

                    model = PeriodicNetwork(**model_setting)
                    model.load_state_dict(
                        torch.load(model_path, map_location=self.device, weights_only=True)[
                            state
                        ]
                    )
                    model.pool = True
                    self.models[f"model_{i}"] = model
                    loaded_models.append(model_path.name)
                    i += 1
                except Exception as e:
                    logger.error(f"Error loading model {model_path.name}: {e}")
            else:
                logger.warning(f"Missing JSON settings for model weights in {model_path.name}")

    return loaded_models
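
An illustrative sketch (names assumed): restore a previously trained committee from the models directory and fail loudly when nothing is found.

Python
# `regressor` is assumed to be an E3NNRegressor instance.
found = regressor.load_trained_models(state="state_best")
if not found:
    raise FileNotFoundError("No trained .torch models found in the models directory.")
print(f"Loaded {len(found)} model(s): {found}")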

plot_history()

Plot the training and validation loss history for each trained model.

This method iterates through the models saved in the model directory, loads their training history, and generates a plot comparing training and validation loss over epochs. The plot is saved as both PNG and PDF files in the figures directory.

The plot will include:
  • X-axis: Epochs (steps)
  • Y-axis: Loss values
  • Two lines: Training loss and Validation loss

Saves the generated plots as:
  • model_name_training.png
  • model_name_training.pdf

Uses Matplotlib to generate the plots and saves them in the configured figures directory.

Raises:

Type Description
FileNotFoundError

If no models are found in the specified model directory.

KeyError

If the model history does not contain expected keys like "history".

Source code in energy_gnome/models/e3nn/regressor.py
Python
def plot_history(self):
    """
    Plot the training and validation loss history for each trained model.

    This method iterates through the models saved in the model directory,
    loads their training history, and generates a plot comparing training and
    validation loss over epochs. The plot is saved as both PNG and PDF files
    in the figures directory.

    The plot will include:
        - X-axis: Epochs (steps)
        - Y-axis: Loss values
        - Two lines: Training loss and Validation loss

    Saves the generated plots as:
        - model_name_training.png
        - model_name_training.pdf

    Uses Matplotlib to generate the plots and saves them in the configured figures directory.

    Raises:
        FileNotFoundError: If no models are found in the specified model directory.
        KeyError: If the model history does not contain expected keys like "history".
    """
    models_found = False  # initialize the flag checked after the loop
    for f in os.listdir(self.models_dir):
        if f.endswith(".torch"):
            model_history_path = self.models_dir / f
            try:
                # Load history
                history = torch.load(
                    model_history_path, map_location=self.device, weights_only=True
                )["history"]
            except KeyError:
                logger.error(f"KeyError: 'history' not found in {model_history_path}")
                raise KeyError(
                    f"Model history for {f} does not contain expected 'history' key."
                )

            # If history is loaded, set flag to True
            models_found = True

            steps = [d["step"] + 1 for d in history]
            loss_train = [d["train"]["loss"] for d in history]
            loss_valid = [d["valid"]["loss"] for d in history]

            fig, ax = plt.subplots(figsize=(7, 3), dpi=150)
            ax.plot(steps, loss_train, "o-", label="Training")
            ax.plot(steps, loss_valid, "o-", label="Validation")
            ax.set_xlabel("epochs")
            ax.set_ylabel("loss")
            ax.legend(frameon=False)
            fig.savefig(
                self.figures_dir / f.replace(".torch", "_training.png"),
                dpi=330,
                bbox_inches="tight",
            )
            fig.savefig(
                self.figures_dir / f.replace(".torch", "_training.pdf"),
                dpi=330,
                bbox_inches="tight",
            )
            plt.show()

    # Raise error if no models are found
    if not models_found:
        logger.error("No models found in the specified model directory.")
        raise FileNotFoundError("No model files found in the models directory.")

plot_parity(predictions_dict, include_ensemble=True)

Plot a parity plot for model predictions and their comparison with true values.

This method generates a scatter plot where the x-axis represents the true values, and the y-axis represents the predicted values from one or more models. It also includes a reference line (1:1 line) and error histograms as insets to visualize the prediction error distribution. Additionally, it calculates and annotates the R² value for each model's predictions and optionally for the ensemble average of all models.

Parameters:

Name Type Description Default
predictions_dict dict

A dictionary where keys are model names (e.g., 'model_1', 'model_2') and values are pandas DataFrames containing the true_value and prediction columns.

required
include_ensemble bool

If True, an ensemble prediction (mean of all model predictions) is included in the plot. Default is True.

True
Source code in energy_gnome/models/e3nn/regressor.py
Python
def plot_parity(
    self, predictions_dict: dict[str, pd.DataFrame], include_ensemble: bool = True
):
    """
    Plot a parity plot for model predictions and their comparison with true values.

    This method generates a scatter plot where the x-axis represents the true values,
    and the y-axis represents the predicted values from one or more models. It also
    includes a reference line (1:1 line) and error histograms as insets to visualize
    the prediction error distribution. Additionally, it calculates and annotates the R²
    value for each model's predictions and optionally for the ensemble average of all models.

    Args:
        predictions_dict (dict): A dictionary where keys are model names (e.g., 'model_1', 'model_2')
            and values are pandas DataFrames containing the `true_value` and `prediction` columns.
        include_ensemble (bool): If `True`, an ensemble prediction (mean of all model predictions)
            is included in the plot. Default is `True`.
    """
    all_predictions = []
    fig, ax = plt.subplots(figsize=(5, 3), dpi=300)

    colors = plt.cm.tab10.colors  # Get distinct colors for different models

    for i, (model, df) in enumerate(predictions_dict.items()):
        y_true = df["true_value"]
        y_predictions = df["prediction"]

        if include_ensemble:
            all_predictions.append(y_predictions)

        error = np.abs(y_predictions - y_true)
        r2 = r2_score(y_true, y_predictions)

        # Scatter plot with unique color
        ax.scatter(
            y_true,
            y_predictions,
            s=6,
            alpha=0.6,
            color=colors[i % len(colors)],
            label=model,
            zorder=1,
        )

        # Reference line (1:1 line)
        ax.axline(
            (np.mean(y_true), np.mean(y_true)), slope=1, lw=0.85, ls="--", color="k", zorder=2
        )

        # Add inset histogram
        if i == 0:  # Create inset only once
            axin = ax.inset_axes([0.65, 0.17, 0.3, 0.3])
        axin.hist(
            error, bins=int(np.sqrt(len(error))), alpha=0.6, color=colors[i % len(colors)]
        )
        axin.hist(error, bins=int(np.sqrt(len(error))), histtype="step", lw=1, color="black")

        # Annotate R² values dynamically
        ax.annotate(
            f"$R^2={r2:1.2f}$",
            xy=(0.05, 0.96 - (i * 0.07)),  # Adjust position dynamically
            xycoords="axes fraction",
            va="top",
            ha="left",
            fontsize=8,
            color=colors[i % len(colors)],
        )

    # Add ensemble prediction plot (if required)
    if include_ensemble:
        ensemble_predictions = np.mean(all_predictions, axis=0)
        ensemble_r2 = r2_score(y_true, ensemble_predictions)

        # Scatter plot for ensemble prediction
        ax.scatter(
            y_true,
            ensemble_predictions,
            s=6,
            alpha=0.6,
            color=colors[len(predictions_dict) % len(colors)],
            label="Ensemble",
            zorder=3,
        )

        # Annotate R² value for ensemble
        ax.annotate(
            f"Ensemble $R^2={ensemble_r2:1.2f}$",
            xy=(0.05, 0.96 - (len(predictions_dict) * 0.07)),  # Adjusted position for ensemble
            xycoords="axes fraction",
            va="top",
            ha="left",
            fontsize=8,
            color=colors[len(predictions_dict) % len(colors)],
        )

    # Set labels
    ax.set_xlabel("True value")
    ax.set_ylabel("Predicted value")

    # Set axis limits to ensure a square parity plot
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    min_ = min(min(xlim), min(ylim))
    max_ = max(max(xlim), max(ylim))
    ax.set_xlim(min_, max_)
    ax.set_ylim(min_, max_)

    # Move legend outside the plot
    ax.legend(loc="upper left", bbox_to_anchor=(1.05, 1), fontsize=8, borderaxespad=0.0)

    # Add title
    fig.suptitle("Parity Plot", fontsize=10)

    # Save figures
    fig.savefig(
        self.figures_dir / (self._model_spec + "_parity.png"), dpi=330, bbox_inches="tight"
    )
    fig.savefig(
        self.figures_dir / (self._model_spec + "_parity.pdf"), dpi=330, bbox_inches="tight"
    )

    # Show plot
    plt.show()
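
An illustrative sketch (names assumed): build the per-model dictionary with evaluate(..., return_df=False) and plot it together with the ensemble average.

Python
# `regressor` and `test_loader` are assumed to exist.
results_by_model = regressor.evaluate(test_loader, return_df=False)
regressor.plot_parity(results_by_model, include_ensemble=True)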

predict(db, confidence_threshold=0.5, save_final=True)

Predict the target property for candidate specialized materials using the regressor models, after filtering materials based on classifier committee confidence.

Parameters:

Name Type Description Default
db GNoMEDatabase

The database containing the materials and their properties.

required
confidence_threshold float

The minimum classifier committee confidence required to keep a material for prediction. Defaults to 0.5.

0.5
save_final bool

Whether to save the final database with predictions. Defaults to True.

True

Returns:

Type Description
DataFrame

A DataFrame containing the predictions, along with the true values and classifier committee confidence scores for the screened materials.

Notes
  • The method filters the materials based on the classifier confidence, then uses the regressor models to predict the target property for the remaining materials.
  • If save_final is set to True, the predictions are saved to the database in the final stage.
Source code in energy_gnome/models/e3nn/regressor.py
Python
def predict(self, db: GNoMEDatabase, confidence_threshold: float = 0.5, save_final=True):
    """
    Predicts the target property for candidate specialized materials using regressor models,
    after filtering materials based on classifier committee confidence.

    Args:
        db (GNoMEDatabase): The database containing the materials and their properties.
        confidence_threshold (float, optional): The minimum classifier committee confidence
                                                required to keep a material for prediction.
                                                Defaults to `0.5`.
        save_final (bool, optional): Whether to save the final database with predictions.
                                    Defaults to `True`.

    Returns:
        (pd.DataFrame): A DataFrame containing the predictions, along with the true values and
                    classifier committee confidence scores for the screened materials.

    Notes:
        - The method filters the materials based on the classifier confidence, then uses the
        regressor models to predict the target property for the remaining materials.
        - If `save_final` is set to True, the predictions are saved to the database in the
        `final` stage.
    """
    logger.info(
        f"Discarding materials with classifier committee confidence threshold < {confidence_threshold}."
    )
    logger.info("Featurizing and loading database as `tg.loader.DataLoader`.")
    dataloader_db, _ = self.create_dataloader(db, confidence_threshold=confidence_threshold)
    logger.info("Predicting the target property for candidate specialized materials.")
    predictions = self._evaluate_unknown(dataloader_db)
    df = db.get_database("processed")
    screened = df[df["classifier_mean"] > confidence_threshold]
    predictions = pd.concat([screened.reset_index(), predictions], axis=1)

    if save_final:
        logger.info("Saving the final database.")
        db.databases["final"] = predictions.copy()
        db.save_database("final")

    return predictions
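
An illustrative sketch (names assumed): screen GNoME candidates with the trained committee, keeping only entries above a 0.6 classifier-confidence threshold and saving the result to the final database stage.

Python
# `regressor` is a trained E3NNRegressor and `gnome_db` a GNoMEDatabase instance.
candidates = regressor.predict(gnome_db, confidence_threshold=0.6, save_final=True)
print(candidates.head())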

set_model_settings(yaml_file=None, **kargs)

Set model settings either from a YAML file or provided keyword arguments.

This method allows setting model settings from multiple sources:

  1. If a yaml_file is provided, it loads the settings from that file.
  2. If additional settings are provided as keyword arguments (kargs), they overwrite the default or loaded settings.

Parameters:

Name Type Description Default
yaml_file (Path, str)

Path to the YAML file containing the model settings.

None
kargs dict

Dictionary of model settings to override the default ones.

{}
Source code in energy_gnome/models/e3nn/regressor.py
Python
def set_model_settings(self, yaml_file: Path | str | None = None, **kargs):
    """
    Set model settings either from a YAML file or provided keyword arguments.

    This method allows setting model settings from multiple sources:
    1. If a `yaml_file` is provided, it loads the settings from that file.
    2. If additional settings are provided as keyword arguments (`kargs`), they overwrite
    the default or loaded settings.

    Args:
        yaml_file (Path, str, optional): Path to the YAML file containing the model settings.
        kargs (dict, optional): Dictionary of model settings to override the default ones.

    """
    # Accessing model settings (YAML FILE)
    if yaml_file:
        self._load_model_setting(yaml_file)

    # Accessing model settings (DEFAULT or provided in kargs)
    for att, defvalue in DEFAULT_E3NN_SETTINGS.items():
        if att in kargs:
            # If a setting is provided via kargs, use it
            setattr(self, att, kargs[att])
        else:
            try:
                # Check if the attribute already exists and is not None
                att_exist = getattr(self, att)
                # If the attribute exists, we verify it's not None (NaN check)
                att_exist = att_exist == att_exist
            except AttributeError:
                # If the attribute does not exist, it will be set to the default value
                att_exist = False

            if not att_exist:
                # If the attribute doesn't exist or is None, use the default value
                setattr(self, att, defvalue)
                logger.warning(f"Using default value {defvalue} for {att} setting")

    # If yaml_file was not provided or is in a different directory, save settings
    if yaml_file is None or os.path.dirname(str(yaml_file)) != str(self.models_dir):
        self._save_model_settings()
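
An illustrative sketch: override selected settings at runtime; anything not passed (or not present in the YAML file) falls back to DEFAULT_E3NN_SETTINGS. The YAML file name below is an assumption, not a documented default.

Python
# `regressor` is assumed to be an E3NNRegressor instance; r_max and n_committers
# are settings referenced elsewhere in this class.
regressor.set_model_settings(r_max=5.0, n_committers=4)

# Or reload settings previously saved alongside the models (file name assumed).
regressor.set_model_settings(yaml_file=regressor.models_dir / "model.e3nn_regressor.yaml")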

set_optimizer_settings(lr, wd)

Set the optimizer settings, including learning rate and weight decay.

This method sets the learning rate and weight decay for the optimizer, which will be used in the training process.

Parameters:

Name Type Description Default
lr float

The learning rate for the optimizer. It should be a positive float.

required
wd float

The weight decay (regularization) parameter for the optimizer. It should be a non-negative float.

required
Source code in energy_gnome/models/e3nn/regressor.py
Python
def set_optimizer_settings(self, lr: float, wd: float):
    """
    Set the optimizer settings, including learning rate and weight decay.

    This method sets the learning rate and weight decay for the optimizer, which
    will be used in the training process.

    Args:
        lr (float): The learning rate for the optimizer. It should be a positive float.
        wd (float): The weight decay (regularization) parameter for the optimizer.
                    It should be a non-negative float.
    """
    self.learning_rate = lr
    self.weight_decay = wd

set_training_settings(n_epochs)

Set the number of epochs for training.

This method sets the number of epochs for the model's training process. It is assumed that the training process will be carried out for the specified number of epochs.

Parameters:

Name Type Description Default
n_epochs int

The number of epochs for training. It should be a positive integer.

required
Source code in energy_gnome/models/e3nn/regressor.py
Python
def set_training_settings(self, n_epochs: int):
    """
    Set the number of epochs for training.

    This method sets the number of epochs for the model's training process.
    It is assumed that the training process will be carried out for the specified
    number of epochs.

    Args:
        n_epochs (int): The number of epochs for training.
                        It should be a positive integer.
    """
    self.n_epochs = n_epochs

energy_gnome.models.GBDTClassifier

Bases: BaseClassifier

Source code in energy_gnome/models/gbdt/classifier.py
Python
class GBDTClassifier(BaseClassifier):
    def __init__(
        self,
        model_name: str,
        models_dir: Path | str = MODELS_DIR,
        figures_dir: Path | str = FIGURES_DIR,
    ):
        """
        Initialize the GBDTClassifier with directories for storing models and figures.

        This class extends `BaseClassifier` to implement a gradient boosted decision tree (GBDT)
        for classification tasks. It sets up the necessary directory structure and configurations
        for training models.

        Args:
            model_name (str): Name of the model, used to create subdirectories.
            models_dir (Path | str, optional): Directory for storing trained model weights.
                                            Defaults to MODELS_DIR from config.
            figures_dir (Path | str, optional): Directory for saving figures and visualizations.
                                                Defaults to FIGURES_DIR from config.

        Attributes:
            _model_spec (str): Specification string used for model identification.
        """
        self._model_spec = "model.gbdt_classifier"
        super().__init__(model_name=model_name, models_dir=models_dir, figures_dir=figures_dir)

    def _find_model_states(self):
        models_states = []
        if any(self.models_dir.iterdir()):
            models_states = [f_ for f_ in self.models_dir.iterdir() if f_.match("*.pkl")]
        return models_states

    def load_trained_models(self) -> list[str]:
        """
        Load trained models from the model directory.

        This method searches for trained models by:
        1. Loading model settings from `.yaml` files matching `_model_spec`.
        2. Loading model weights from `.pkl` files.
        3. Storing the loaded models in `self.models`.

        Returns:
            list[str]: A list of `.pkl` model filenames that were found in the directory.
        """
        # Load model settings from YAML files matching the model spec
        for yaml_path in self.models_dir.glob(f"*{self._model_spec}*.yaml"):
            self.set_model_settings(yaml_file=yaml_path)

        i = 0
        loaded_models = []
        for model_path in self.models_dir.glob("*.pkl"):
            try:
                logger.info(f"Loading model from {model_path}")
                with open(model_path, "rb") as file_:
                    model = pickle.load(file_)
                    model.pool = True
                    self.models[f"model_{i}"] = model
                    loaded_models.append(model_path.name)
                    i += 1
            except Exception as e:
                logger.error(f"Error loading model {model_path.name}: {e}")

        return loaded_models

    def _load_model_setting(self, yaml_path):
        """
        Load model settings from a YAML file and assign corresponding attributes.

        This method loads settings from the specified YAML file and sets model attributes
        based on the values in `DEFAULT_GBDT_SETTINGS`. If an attribute is missing, a KeyError
        will be raised.

        Args:
            yaml_path (Path): Path to the YAML file containing the model settings.
        """
        settings = load_yaml(yaml_path)
        for att, _ in DEFAULT_GBDT_SETTINGS.items():
            setattr(self, att, settings[att])

    def _save_model_settings(self):
        """
        Save the current model settings to a YAML file.

        This method saves the current values of the model's attributes (based on `DEFAULT_GBDT_SETTINGS`)
        to a YAML file in the models directory.

        The file is named based on the model specification (`self._model_spec`) with a `.yaml` extension.

        Raises:
            IOError: If the saving process fails.
        """
        settings_path = self.models_dir / (self._model_spec + ".yaml")
        settings = {}
        for att, _ in DEFAULT_GBDT_SETTINGS.items():
            settings[att] = getattr(self, att)
        logger.info(f"saving models 'general' settings in {settigns_path}")
        save_yaml(settings, settigns_path)

    def set_model_settings(self, yaml_file: Path | str | None = None, **kargs):
        """
        Set model settings either from a YAML file or provided keyword arguments.

        This method allows setting model settings from multiple sources:
        1. If a `yaml_file` is provided, it loads the settings from that file.
        2. If additional settings are provided as keyword arguments (`kargs`), they overwrite
        the default or loaded settings.

        Args:
            yaml_file (Path, str, optional): Path to the YAML file containing the model settings.
            kargs (dict, optional): Dictionary of model settings to override the default ones.

        """
        # Accessing model settings (YAML FILE)
        if yaml_file:
            self._load_model_setting(yaml_file)

        # Accessing model settings
        for att, defvalue in DEFAULT_GBDT_SETTINGS.items():
            if att in kargs:
                setattr(self, att, kargs[att])
            else:
                try:
                    att_exist = getattr(self, att)
                    att_exist = att_exist == att_exist
                except AttributeError:
                    att_exist = False
                if not att_exist:
                    setattr(self, att, defvalue)
                    logger.warning(f"using default value {defvalue} for {att} setting")

        if yaml_file is None or os.path.dirname(str(yaml_file)) != str(self.models_dir):
            self._save_model_settings()

    def featurize_db(
        self,
        databases: list[BaseDatabase],
        subset: str | None = None,
        max_dof: int = None,
        mute_warnings: bool = True,
    ) -> pd.DataFrame:
        """
        Load and featurize the specified databases efficiently.

        This method processes the given list of databases and applies a featurization pipeline to each dataset.
        It handles the loading of raw data from databases, extracts relevant features such as composition and
        structure, and applies transformations to create numerical features for model training.

        Args:
            databases (list[BaseDatabase]): A list of databases (or a single database) from which to
                                            load and featurize data.
            subset (str, optional): The subset of data to load from each database (default is None).
            max_dof (int, optional): Maximum degrees of freedom for the feature space. If not provided,
                                        it is automatically calculated.
            mute_warnings (bool, optional): Whether to suppress warnings during featurization (default is True).

        Returns:
            (pd.DataFrame): A DataFrame containing the featurized data, including numerical features
                            and a column for `is_specialized`.

        Warnings:
            If `mute_warnings` is set to False, third-party warnings related to the featurization
            process will be displayed.

        Notes:
            - If the database contains a `Reduced Formula` or `formula_pretty` column, compositions are
              parsed accordingly.
            - The resulting DataFrame retains only numerical features, removes NaN rows, and includes an
              `is_specialized` column.
            - The method supports batch processing using `tqdm` for progress visualization during featurization.
        """

        if not isinstance(databases, list):
            databases = [databases]

        if mute_warnings:
            logger.warning(
                "Third-party warnings disabled. Set 'mute_warnings=False' to enable them."
            )
            self.max_dof = max_dof

        if isinstance(databases[0], GNoMEDatabase):
            dataset = databases[0].get_database("raw")
        else:
            # Load all database subsets efficiently
            dataset = pd.concat(
                [db.load_classifier_data(subset) for db in databases], ignore_index=True
            )

        if dataset.empty:
            logger.warning("No data loaded for featurization.")
            return pd.DataFrame()

        with warnings.catch_warnings():
            if mute_warnings:
                warnings.simplefilter("ignore")

            # Apply batch processing for transformations
            tqdm.pandas(desc="Processing structures")
            if isinstance(databases[0], GNoMEDatabase):
                dataset["composition"] = dataset["Reduced Formula"].progress_apply(Composition)
            else:
                dataset["composition"] = dataset["formula_pretty"].progress_apply(Composition)
            dataset["structure"] = dataset["cif_path"].progress_apply(
                lambda x: Structure.from_file(x)
            )
            dataset["formula"] = dataset["composition"].apply(lambda x: x.reduced_formula)
            dataset["is_specialized"] = dataset["is_specialized"].astype(float)

            # Run the featurization pipeline
            if isinstance(databases[0], GNoMEDatabase):
                db_feature = featurizing_structure_pipeline(dataset)
            else:
                db_feature = featurizing_structure_pipeline(dataset)

            # Retain only numerical features
            db_feature = db_feature.select_dtypes(exclude=["object", "bool"])
            db_feature["is_specialized"] = dataset.set_index("formula")["is_specialized"]
            db_feature.dropna(inplace=True)  # Remove NaN rows

            # Adjust max_dof if needed
            if max_dof is None:
                n_items, n_features = db_feature.shape
                self.max_dof = min(n_items // 10, n_features - 1)

        if not isinstance(databases[0], GNoMEDatabase):
            # Log class distribution
            specialized_counts = db_feature["is_specialized"].value_counts()
            logger.debug(f"Number of specialized examples: {specialized_counts.get(1.0, 0)}")
            logger.debug(f"Number of non-specialized examples: {specialized_counts.get(0.0, 0)}")

        return db_feature

    def _build_classifier(self, n_jobs: int = 1, max_dof: int = None):
        """
        Builds a Gradient Boosting Classifier pipeline with feature selection.

        This method constructs a classification pipeline that includes:
        1. Standard scaling of features.
        2. Feature selection using Recursive Feature Elimination (RFE) with a Gradient Boosting Classifier
           as the estimator.
        3. A Gradient Boosting Classifier as the final model for classification.

        It performs a grid search to tune hyperparameters for the number of estimators and the number of
        features to select.
        The search uses Stratified K-Folds cross-validation and evaluates performance using the ROC-AUC score.

        Args:
            n_jobs (int, optional): The number of CPU cores to use for parallel processing during the grid search.
                                    Default is 1.
            max_dof (int, optional): The maximum degrees of freedom for feature selection. If provided,
                                     it will control the number of features to select.

        Returns:
            (GridSearchCV): A GridSearchCV object that encapsulates the classifier pipeline and hyperparameter
                            tuning process.

        Notes:
            - The `n_estimators` in the classifier will be searched over the values [50, 100, 250, 500].
            - The number of features to select for RFE is determined by the `max_dof` argument,
              with values [int(max_dof * 0.5), int(max_dof * 1)].
            - Stratified K-Folds cross-validation is used with 4 splits, shuffling enabled, and a fixed random seed (0).
            - ROC-AUC score is used for model evaluation.
        """
        pipe = Pipeline(
            [
                ("scaler", StandardScaler()),
                ("feature_selector", RFE(estimator=GradientBoostingClassifier(), step=25)),
                ("classifier", GradientBoostingClassifier()),
            ],
        )
        param_grid = {
            "classifier__n_estimators": [50, 100, 250, 500],
            "feature_selector__n_features_to_select": [int(max_dof * 0.5), int(max_dof * 1)],
        }

        stratified_kfold = StratifiedKFold(n_splits=4, shuffle=True, random_state=0)
        search = GridSearchCV(
            pipe, param_grid, n_jobs=n_jobs, verbose=2, cv=stratified_kfold, scoring="roc_auc"
        )

        return search

    def compile_(  # same here, _ to mute the pre-commit
        self,
        n_jobs: int = 1,
    ):
        """
        Initializes the classifier pipeline and sets up the hyperparameter search.

        This method calls the `_build_classifier` method to construct a Gradient Boosting Classifier pipeline,
        including feature scaling, feature selection via Recursive Feature Elimination (RFE), and classification.
        It then stores the resulting `GridSearchCV` object in the `self.search` attribute.

        Args:
            n_jobs (int, optional): The number of CPU cores to use for parallel processing during the grid search.
                                    Default is 1.

        Returns:
            None

        Notes:
            This method does not return any value but sets the `self.search` attribute with the initialized
            `GridSearchCV` object, which contains the classifier pipeline and hyperparameter tuning setup.
        """

        logger.info("Build classifiers")
        self.search = self._build_classifier(n_jobs=n_jobs, max_dof=self.max_dof)

    def fit(self, df: pd.DataFrame):
        """
        Train and save multiple models using a `GridSearchCV` classifier.

        This method iterates through the specified number of committers (`n_committers`) and performs the following:
        1. Trains a model for each committer using the `GridSearchCV` pipeline (`self.search`) with the given
           data (`df`).
        2. Saves each trained model as a `.pkl` file in the specified `self.models_dir`.

        Args:
            df (pd.DataFrame): The input dataset. The last column is assumed to be the target variable,
                                and all other columns are used as features.

        Returns:
            None

        Notes:
            - The models are saved as `.pkl` files in the `self.models_dir` directory, with filenames
            following the pattern `{self._model_spec}.rep{i}.pkl`, where `i` is the index of the committer.
            - The `GridSearchCV` pipeline, defined in `self.search`, is used for training.
        """
        models = {}
        for i in tqdm(range(self.n_committers), desc="training models"):
            models[f"model_{i}"] = self.search.fit(df.iloc[:, :-1], df.iloc[:, -1])

        for i in tqdm(range(self.n_committers), desc="saving"):
            model_path = self.models_dir / (self._model_spec + f".rep{i}.pkl")
            model_ = models[f"model_{i}"]
            with open(model_path, "wb") as file_:
                pickle.dump(model_, file_)

    def evaluate(self, df: pd.DataFrame, return_df: bool = False):
        """
        Evaluate the performance of multiple models on a given dataset.

        This method evaluates each model's predictions on the provided dataset (`df`) and returns the predictions
        either as a DataFrame with true values and model predictions or as a dictionary of model predictions.

        Args:
            df (pd.DataFrame): The dataset containing the features and the target property (last column).
                            The model predictions are based on all columns except the last one (target property).
            return_df (bool, optional): If True, returns a DataFrame with true values and predictions from each model.
                                        If False, returns a dictionary with model predictions. Defaults to `False`.

        Returns:
            (pd.DataFrame): If `return_df=True`, returns a pandas DataFrame where each column
                            corresponds to predictions and loss for each model.
                            The columns include:
                                - `true_value`: Ground truth values.
                                - `model_i_prediction`: Predictions from model `i`.

            (dict[str, pd.DataFrame]): If `return_df=False`, returns a dictionary where each key is
                    a model identifier (e.g., `model_0`, `model_1`, ...) and the value is a
                    DataFrame containing the following columns:
                        - `true_value`: Ground truth values.
                        - `prediction`: Predictions from the model.

        Notes:
            - The target property in `df` must match `self.target_property`.
            - Each model's predictions are generated using the `predict_proba` method, which is expected to
              return probabilities.
            - The method assumes the target property is in the last column of the input `df` and features are in all
              other columns.
        """
        if return_df:
            predictions = pd.DataFrame(columns=["true"])
            predictions["true"] = df[self.target_property]

            for i in tqdm(range(self.n_committers), desc="models"):
                predictions[f"model_{i}"] = np.empty((len(df), 1)).tolist()
                predictions[f"model_{i}"] = self.models[f"model_{i}"].predict_proba(
                    df.iloc[:, :-1]
                )[:, 1]

        else:
            predictions = {}
            for i in tqdm(range(self.n_committers), desc="models"):
                predictions[f"model_{i}"] = pd.DataFrame()
                predictions[f"model_{i}"]["true_value"] = df[self.target_property]
                predictions[f"model_{i}"]["prediction"] = self.models[f"model_{i}"].predict_proba(
                    df.iloc[:, :-1]
                )[:, 1]

        return predictions

    def plot_performance(
        self, predictions_dict: dict[str, pd.DataFrame], include_ensemble: bool = True
    ):
        """
        Plot model performance evaluation curves: ROC, Precision, and Recall.

        This method generates a multi-panel plot that visualizes the performance of different models on
        classification tasks. It includes:
        - ROC curve with AUC (Area Under the Curve)
        - Precision-Recall curve
        - Recall-Threshold curve

        The method also supports an optional ensemble model performance evaluation by averaging
        individual model predictions.

        Args:
            predictions_dict (dict[str, pd.DataFrame]): A dictionary where keys are model names and values are
                                                        DataFrames containing the `true_value` and `prediction` columns.
                                                        Each model's predictions will be plotted.
            include_ensemble (bool, optional): If `True`, the ensemble model performance will also be plotted, which is
                                            based on averaging the predictions of all models. Defaults to `True`.

        Returns:
            None: The method generates and saves the performance plots as PNG and PDF files.

        Notes:
            - The method assumes that the `predictions_dict` contains the model predictions (in the `prediction` column)
            and the true labels (in the `true_value` column).
            - The ROC curve is evaluated using the `roc_curve` function, while the Precision and Recall curves are
            generated using `precision_recall_curve`.
            - The final figure is saved in both PNG and PDF formats in the directory defined by `self.figures_dir`.
        """

        all_predictions = []
        colors = plt.cm.tab10.colors

        # Create a single figure and a grid layout
        fig = plt.figure(figsize=(7, 2), dpi=150)
        grid = fig.add_gridspec(1, 3, wspace=0.5, hspace=0.07)
        ax = [fig.add_subplot(grid[j]) for j in range(3)]

        for i, (model, data) in enumerate(predictions_dict.items()):
            y_true = data["true_value"].values  # Ensure correct format
            y_predictions = data["prediction"].values

            if include_ensemble:
                all_predictions.append(y_predictions)

            fpr, tpr, _ = roc_curve(y_true, y_predictions)
            precision, recall, thresholds = precision_recall_curve(y_true, y_predictions)

            # Use colors from plt.cm.tab10
            color = colors[i % len(colors)]

            ax[0].plot(fpr, tpr, label=f"{model} (AUC: {auc(fpr, tpr):.2f})", color=color)
            ax[1].plot(thresholds, precision[:-1], label=f"{model}", color=color)
            ax[2].plot(thresholds, recall[:-1], label=f"{model}", color=color)

        # If ensemble is enabled, add an extra curve
        if include_ensemble and all_predictions:
            y_ensemble = np.mean(np.column_stack(all_predictions), axis=1)
            fpr, tpr, _ = roc_curve(y_true, y_ensemble)
            precision, recall, thresholds = precision_recall_curve(y_true, y_ensemble)

            ax[0].plot(fpr, tpr, label="Ensemble", color="black", linestyle="--")
            ax[1].plot(thresholds, precision[:-1], label="Ensemble", color="black", linestyle="--")
            ax[2].plot(thresholds, recall[:-1], label="Ensemble", color="black", linestyle="--")

        # Add diagonal reference line in ROC curve
        ax[0].axline([0.5, 0.5], slope=1, lw=0.85, ls="--", color="k", zorder=2)
        ax[0].set_xlabel("FPR")
        ax[0].set_ylabel("TPR")
        ax[0].set_title("ROC Curve")
        ax[0].set_aspect("equal")

        ax[1].set_xlabel("Threshold")
        ax[1].set_ylabel("Precision")
        ax[1].set_title("Precision Curve")
        ax[1].set_ylim([-0.05, 1.05])

        ax[2].set_xlabel("Threshold")
        ax[2].set_ylabel("Recall")
        ax[2].set_title("Recall Curve")
        ax[2].set_ylim([-0.05, 1.05])

        # Move legend outside the plot
        ax[2].legend(loc="upper left", bbox_to_anchor=(1.05, 1), fontsize=8, borderaxespad=0.0)

        fig.suptitle("Model Performance")
        fig.subplots_adjust(top=0.8, right=0.75)  # Adjust right margin for legend space

        # Save figure
        fig.savefig(self.figures_dir / (self._model_spec + ".png"), dpi=330, bbox_inches="tight")
        fig.savefig(self.figures_dir / (self._model_spec + ".pdf"), dpi=330, bbox_inches="tight")

        plt.show()

    def _evaluate_unknown(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Evaluate and predict the target property using a committee of classifiers.

        This method takes a DataFrame of input features, applies a set of pre-trained classifiers
        (stored in `self.models`), and generates probability predictions for each classifier in the committee.
        It then computes the mean and standard deviation of the classifier predictions to provide an
        overall prediction and uncertainty.

        Args:
            df (pd.DataFrame): A DataFrame where each row represents a sample, and the columns represent features.
                                The last column is assumed to be the target property, which is not used in prediction.

        Returns:
            (pd.DataFrame): A DataFrame containing:
                - The predicted probability from each classifier in the committee (labeled as `classifier_{i}`).
                - The mean of all classifier predictions (`classifier_mean`).
                - The standard deviation of the classifier predictions (`classifier_std`).

        Notes:
            - The input `df` should have one or more features for prediction, and the last column is ignored
              during prediction.
            - The method assumes that the classifiers in `self.models` are capable of generating probability predictions
            using the `predict_proba` method, and that the classifiers are indexed from `0` to `n_committers-1`.
            - The `classifier_mean` column represents the average of all classifiers' probability predictions.
            - The `classifier_std` column gives a measure of the variance (uncertainty) of the predictions across
              the classifiers.
        """
        predictions = pd.DataFrame(index=df.index)

        for i in tqdm(range(self.n_committers), desc="models"):
            predictions[f"classifier_{i}"] = self.models[f"model_{i}"].predict_proba(
                df.iloc[:, :-1]
            )[:, 1]
        predictions["classifier_mean"] = predictions[
            [f"classifier_{i}" for i in range(self.n_committers)]
        ].mean(axis=1)
        predictions["classifier_std"] = predictions[
            [f"classifier_{i}" for i in range(self.n_committers)]
        ].std(axis=1)
        return predictions

    def screen(self, db: GNoMEDatabase, save_processed: bool = True) -> pd.DataFrame:
        """
        Screen the database for specialized materials using classifier predictions.

        This method performs the following steps:
        1. Featurizes the database using `featurize_db`.
        2. Evaluates the featurized data using a committee of classifiers to generate predictions.
        3. Combines the predictions with the original database and removes rows with missing values
           or unqualified materials.
        4. Optionally saves the processed (screened) database for future use.

        Args:
            db (GNoMEDatabase): A `GNoMEDatabase` object containing the raw data to be screened.
            save_processed (bool, optional): Whether to save the screened data to the database. Defaults to `True`.

        Returns:
            (pd.DataFrame): A DataFrame containing the original data combined with classifier predictions,
                        excluding materials that have missing or unqualified values for screening.

        Notes:
            - The method assumes that `featurize_db` and `_evaluate_unknown` methods are defined and function correctly.
            - The `classifier_mean` column in the returned DataFrame reflects the mean classifier prediction, which is
              used to screen specialized materials.
            - The `is_specialized` column is dropped from the screened DataFrame.
        """
        logger.info("Featurizing the database...")
        df_class = self.featurize_db(db)
        logger.info("Screening the database for specialized materials.")
        predictions = self._evaluate_unknown(df_class)
        gnome_df = db.get_database("raw")
        gnome_screened = pd.concat([gnome_df, predictions.reset_index(drop=True)], axis=1)
        gnome_screened.drop(columns=["is_specialized"], inplace=True)
        gnome_screened = gnome_screened[gnome_screened["classifier_mean"].notna()]

        if save_processed:
            logger.info("Saving the screened database.")
            db.databases["processed"] = gnome_screened.copy()
            db.save_database("processed")

        return gnome_screened

__init__(model_name, models_dir=MODELS_DIR, figures_dir=FIGURES_DIR)

This class extends BaseClassifier to implement a gradient boosted decision tree (GBDT) for classification tasks. It sets up the necessary directory structure and configurations for training models.

Parameters:

    model_name (str, required): Name of the model, used to create subdirectories.
    models_dir (Path | str, optional): Directory for storing trained model weights. Defaults to MODELS_DIR from config.
    figures_dir (Path | str, optional): Directory for saving figures and visualizations. Defaults to FIGURES_DIR from config.

Attributes:

    _model_spec (str): Specification string used for model identification.

Source code in energy_gnome/models/gbdt/classifier.py
Python
def __init__(
    self,
    model_name: str,
    models_dir: Path | str = MODELS_DIR,
    figures_dir: Path | str = FIGURES_DIR,
):
    """
    Initialize the GBDTClassifier with directories for storing models and figures.

    This class extends `BaseClassifier` to implement a gradient boosted decision tree (GBDT)
    for classification tasks. It sets up the necessary directory structure and configurations
    for training models.

    Args:
        model_name (str): Name of the model, used to create subdirectories.
        models_dir (Path | str, optional): Directory for storing trained model weights.
                                        Defaults to MODELS_DIR from config.
        figures_dir (Path | str, optional): Directory for saving figures and visualizations.
                                            Defaults to FIGURES_DIR from config.

    Attributes:
        _model_spec (str): Specification string used for model identification.
    """
    self._model_spec = "model.gbdt_classifier"
    super().__init__(model_name=model_name, models_dir=models_dir, figures_dir=figures_dir)
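
Example: a minimal usage sketch. It assumes the package is installed and that GBDTClassifier is importable from energy_gnome.models (the exact import path may differ); the model name and directories are illustrative placeholders.

Python
from pathlib import Path

from energy_gnome.models import GBDTClassifier  # assumed import path

# Instantiate the classifier; name and directories are hypothetical examples.
classifier = GBDTClassifier(
    model_name="perovskite_screening",   # used to create subdirectories
    models_dir=Path("models"),           # where trained .pkl committee members are stored
    figures_dir=Path("figures"),         # where performance plots are saved
)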

compile_(n_jobs=1)

Initializes the classifier pipeline and sets up the hyperparameter search.

This method calls the _build_classifier method to construct a Gradient Boosting Classifier pipeline, including feature scaling, feature selection via Recursive Feature Elimination (RFE), and classification. It then stores the resulting GridSearchCV object in the self.search attribute.

Parameters:

    n_jobs (int, optional): The number of CPU cores to use for parallel processing during the grid search. Default is 1.

Returns:

    None

Notes

This method does not return any value but sets the self.search attribute with the initialized GridSearchCV object, which contains the classifier pipeline and hyperparameter tuning setup.

Source code in energy_gnome/models/gbdt/classifier.py
Python
def compile_(  # same here, _ to mute the pre-commit
    self,
    n_jobs: int = 1,
):
    """
    Initializes the classifier pipeline and sets up the hyperparameter search.

    This method calls the `_build_classifier` method to construct a Gradient Boosting Classifier pipeline,
    including feature scaling, feature selection via Recursive Feature Elimination (RFE), and classification.
    It then stores the resulting `GridSearchCV` object in the `self.search` attribute.

    Args:
        n_jobs (int, optional): The number of CPU cores to use for parallel processing during the grid search.
                                Default is 1.

    Returns:
        None

    Notes:
        This method does not return any value but sets the `self.search` attribute with the initialized
        `GridSearchCV` object, which contains the classifier pipeline and hyperparameter tuning setup.
    """

    logger.info("Build classifiers")
    self.search = self._build_classifier(n_jobs=n_jobs, max_dof=self.max_dof)
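
Example: an illustrative sketch, not part of the library source. It assumes `classifier` was created as above and that featurize_db() has already been called, since that call sets classifier.max_dof, which compile_ forwards to the feature selector.

Python
# Builds the scaler + RFE + GradientBoostingClassifier pipeline and stores the
# GridSearchCV object in classifier.search (nothing is trained yet).
classifier.compile_(n_jobs=4)

The resulting search tunes n_estimators and the number of RFE-selected features with 4-fold stratified cross-validation scored on ROC-AUC, as shown in the source above.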

evaluate(df, return_df=False)

Evaluate the performance of multiple models on a given dataset.

This method evaluates each model's predictions on the provided dataset (df) and returns the predictions either as a DataFrame with true values and model predictions or as a dictionary of model predictions.

Parameters:

    df (DataFrame, required): The dataset containing the features and the target property (last column). The model predictions are based on all columns except the last one (target property).
    return_df (bool, optional): If True, returns a DataFrame with true values and predictions from each model. If False, returns a dictionary with model predictions. Defaults to False.

Returns:

    DataFrame: If return_df=True, a pandas DataFrame with a true column holding the ground-truth labels and one model_{i} column per model holding its predicted probabilities.
    dict[str, DataFrame]: If return_df=False, a dictionary where each key is a model identifier (e.g., model_0, model_1, ...) and the value is a DataFrame with the columns true_value (ground truth) and prediction (predicted probabilities from that model).

Notes
  • The target property in df must match self.target_property.
  • Each model's predictions are generated using the predict_proba method, which is expected to return probabilities.
  • The method assumes the target property is in the last column of the input df and features are in all other columns.
Source code in energy_gnome/models/gbdt/classifier.py
Python
def evaluate(self, df: pd.DataFrame, return_df: bool = False):
    """
    Evaluate the performance of multiple models on a given dataset.

    This method evaluates each model's predictions on the provided dataset (`df`) and returns the predictions
    either as a DataFrame with true values and model predictions or as a dictionary of model predictions.

    Args:
        df (pd.DataFrame): The dataset containing the features and the target property (last column).
                        The model predictions are based on all columns except the last one (target property).
        return_df (bool, optional): If True, returns a DataFrame with true values and predictions from each model.
                                    If False, returns a dictionary with model predictions. Defaults to `False`.

    Returns:
        (pd.DataFrame): If `return_df=True`, returns a pandas DataFrame where each column
                        corresponds to predictions and loss for each model.
                        The columns include:
                            - `true_value`: Ground truth values.
                            - `model_i_prediction`: Predictions from model `i`.

        (dict[str, pd.DataFrame]): If `return_df=False`, returns a dictionary where each key is
                a model identifier (e.g., `model_0`, `model_1`, ...) and the value is a
                DataFrame containing the following columns:
                    - `true_value`: Ground truth values.
                    - `prediction`: Predictions from the model.

    Notes:
        - The target property in `df` must match `self.target_property`.
        - Each model's predictions are generated using the `predict_proba` method, which is expected to
          return probabilities.
        - The method assumes the target property is in the last column of the input `df` and features are in all
          other columns.
    """
    if return_df:
        predictions = pd.DataFrame(columns=["true"])
        predictions["true"] = df[self.target_property]

        for i in tqdm(range(self.n_committers), desc="models"):
            predictions[f"model_{i}"] = np.empty((len(df), 1)).tolist()
            predictions[f"model_{i}"] = self.models[f"model_{i}"].predict_proba(
                df.iloc[:, :-1]
            )[:, 1]

    else:
        predictions = {}
        for i in tqdm(range(self.n_committers), desc="models"):
            predictions[f"model_{i}"] = pd.DataFrame()
            predictions[f"model_{i}"]["true_value"] = df[self.target_property]
            predictions[f"model_{i}"]["prediction"] = self.models[f"model_{i}"].predict_proba(
                df.iloc[:, :-1]
            )[:, 1]

    return predictions
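
Example: a sketch of the two return modes, assuming trained models have been loaded into classifier.models and that test_df is a placeholder featurized DataFrame whose last column is the target property.

Python
# Dictionary mode (default): one DataFrame per committee member.
pred_dict = classifier.evaluate(test_df)
print(pred_dict["model_0"].columns.tolist())   # ['true_value', 'prediction']

# DataFrame mode: a single table with a `true` column and one `model_{i}` column per model.
pred_df = classifier.evaluate(test_df, return_df=True)
print(pred_df.head())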

featurize_db(databases, subset=None, max_dof=None, mute_warnings=True)

Load and featurize the specified databases efficiently.

This method processes the given list of databases and applies a featurization pipeline to each dataset. It handles the loading of raw data from databases, extracts relevant features such as composition and structure, and applies transformations to create numerical features for model training.

Parameters:

    databases (list[BaseDatabase], required): A list of databases (or a single database) from which to load and featurize data.
    subset (str, optional): The subset of data to load from each database (default is None).
    max_dof (int, optional): Maximum degrees of freedom for the feature space. If not provided, it is automatically calculated.
    mute_warnings (bool, optional): Whether to suppress warnings during featurization (default is True).

Returns:

    DataFrame: A DataFrame containing the featurized data, including numerical features and a column for is_specialized.

Notes
  • If the database contains a Reduced Formula or formula_pretty column, compositions are parsed accordingly.
  • The resulting DataFrame retains only numerical features, removes NaN rows, and includes an is_specialized column.
  • The method supports batch processing using tqdm for progress visualization during featurization.
Source code in energy_gnome/models/gbdt/classifier.py
Python
def featurize_db(
    self,
    databases: list[BaseDatabase],
    subset: str | None = None,
    max_dof: int = None,
    mute_warnings: bool = True,
) -> pd.DataFrame:
    """
    Load and featurize the specified databases efficiently.

    This method processes the given list of databases and applies a featurization pipeline to each dataset.
    It handles the loading of raw data from databases, extracts relevant features such as composition and
    structure, and applies transformations to create numerical features for model training.

    Args:
        databases (list[BaseDatabase]): A list of databases (or a single database) from which to
                                        load and featurize data.
        subset (str, optional): The subset of data to load from each database (default is None).
        max_dof (int, optional): Maximum degrees of freedom for the feature space. If not provided,
                                    it is automatically calculated.
        mute_warnings (bool, optional): Whether to suppress warnings during featurization (default is True).

    Returns:
        (pd.DataFrame): A DataFrame containing the featurized data, including numerical features
                        and a column for `is_specialized`.

    Warnings:
        If `mute_warnings` is set to False, third-party warnings related to the featurization
        process will be displayed.

    Notes:
        - If the database contains a `Reduced Formula` or `formula_pretty` column, compositions are
          parsed accordingly.
        - The resulting DataFrame retains only numerical features, removes NaN rows, and includes an
          `is_specialized` column.
        - The method supports batch processing using `tqdm` for progress visualization during featurization.
    """

    if not isinstance(databases, list):
        databases = [databases]

    if mute_warnings:
        logger.warning(
            "Third-party warnings disabled. Set 'mute_warnings=False' to enable them."
        )
        self.max_dof = max_dof

    if isinstance(databases[0], GNoMEDatabase):
        dataset = databases[0].get_database("raw")
    else:
        # Load all database subsets efficiently
        dataset = pd.concat(
            [db.load_classifier_data(subset) for db in databases], ignore_index=True
        )

    if dataset.empty:
        logger.warning("No data loaded for featurization.")
        return pd.DataFrame()

    with warnings.catch_warnings():
        if mute_warnings:
            warnings.simplefilter("ignore")

        # Apply batch processing for transformations
        tqdm.pandas(desc="Processing structures")
        if isinstance(databases[0], GNoMEDatabase):
            dataset["composition"] = dataset["Reduced Formula"].progress_apply(Composition)
        else:
            dataset["composition"] = dataset["formula_pretty"].progress_apply(Composition)
        dataset["structure"] = dataset["cif_path"].progress_apply(
            lambda x: Structure.from_file(x)
        )
        dataset["formula"] = dataset["composition"].apply(lambda x: x.reduced_formula)
        dataset["is_specialized"] = dataset["is_specialized"].astype(float)

        # Run the featurization pipeline
        if isinstance(databases[0], GNoMEDatabase):
            db_feature = featurizing_structure_pipeline(dataset)
        else:
            db_feature = featurizing_structure_pipeline(dataset)

        # Retain only numerical features
        db_feature = db_feature.select_dtypes(exclude=["object", "bool"])
        db_feature["is_specialized"] = dataset.set_index("formula")["is_specialized"]
        db_feature.dropna(inplace=True)  # Remove NaN rows

        # Adjust max_dof if needed
        if max_dof is None:
            n_items, n_features = db_feature.shape
            self.max_dof = min(n_items // 10, n_features - 1)

    if not isinstance(databases[0], GNoMEDatabase):
        # Log class distribution
        specialized_counts = db_feature["is_specialized"].value_counts()
        logger.debug(f"Number of specialized examples: {specialized_counts.get(1.0, 0)}")
        logger.debug(f"Number of non-specialized examples: {specialized_counts.get(0.0, 0)}")

    return db_feature
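
Example: an illustrative sketch. `labeled_db` stands in for any BaseDatabase subclass holding labeled structures, and the "training" subset name is an assumption; substitute whatever subsets your databases actually define.

Python
# Featurize the labeled data; max_dof is inferred automatically when not given.
train_df = classifier.featurize_db([labeled_db], subset="training")

print(train_df.shape)       # numerical features plus the trailing `is_specialized` column
print(classifier.max_dof)   # set during featurization, later used by compile_()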

fit(df)

Train and save multiple models using a GridSearchCV classifier.

This method iterates through the specified number of committers (n_committers) and performs the following:

1. Trains a model for each committer using the GridSearchCV pipeline (self.search) with the given data (df).
2. Saves each trained model as a .pkl file in the specified self.models_dir.

Parameters:

    df (DataFrame, required): The input dataset. The last column is assumed to be the target variable, and all other columns are used as features.

Returns:

    None

Notes
  • The models are saved as .pkl files in the self.models_dir directory, with filenames following the pattern {self._model_spec}.rep{i}.pkl, where i is the index of the committer.
  • The GridSearchCV pipeline, defined in self.search, is used for training.
Source code in energy_gnome/models/gbdt/classifier.py
Python
def fit(self, df: pd.DataFrame):
    """
    Train and save multiple models using a `GridSearchCV` classifier.

    This method iterates through the specified number of committers (`n_committers`) and performs the following:
    1. Trains a model for each committer using the `GridSearchCV` pipeline (`self.search`) with the given
       data (`df`).
    2. Saves each trained model as a `.pkl` file in the specified `self.models_dir`.

    Args:
        df (pd.DataFrame): The input dataset. The last column is assumed to be the target variable,
                            and all other columns are used as features.

    Returns:
        None

    Notes:
        - The models are saved as `.pkl` files in the `self.models_dir` directory, with filenames
        following the pattern `{self._model_spec}.rep{i}.pkl`, where `i` is the index of the committer.
        - The `GridSearchCV` pipeline, defined in `self.search`, is used for training.
    """
    models = {}
    for i in tqdm(range(self.n_committers), desc="training models"):
        models[f"model_{i}"] = self.search.fit(df.iloc[:, :-1], df.iloc[:, -1])

    for i in tqdm(range(self.n_committers), desc="saving"):
        model_path = self.models_dir / (self._model_spec + f".rep{i}.pkl")
        model_ = models[f"model_{i}"]
        with open(model_path, "wb") as file_:
            pickle.dump(model_, file_)
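
Example: a sketch assuming featurize_db() and compile_() have already been called and that train_df holds the featurized data with the target (is_specialized) in the last column.

Python
# Trains n_committers models via the GridSearchCV pipeline and pickles each one
# in models_dir under the pattern {model_spec}.rep{i}.pkl.
classifier.fit(train_df)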

load_trained_models()

Load trained models from the model directory.

This method searches for trained models by:

1. Loading model settings from .yaml files matching _model_spec.
2. Loading model weights from .pkl files.
3. Storing the loaded models in self.models.

Returns:

    list[str]: A list of .pkl model filenames that were found in the directory.

Source code in energy_gnome/models/gbdt/classifier.py
Python
def load_trained_models(self) -> list[str]:
    """
    Load trained models from the model directory.

    This method searches for trained models by:
    1. Loading model settings from `.yaml` files matching `_model_spec`.
    2. Loading model weights from `.pkl` files.
    3. Storing the loaded models in `self.models`.

    Returns:
        list[str]: A list of `.pkl` model filenames that were found in the directory.
    """
    # Load model settings from YAML files matching the model spec
    for yaml_path in self.models_dir.glob(f"*{self._model_spec}*.yaml"):
        self.set_model_settings(yaml_file=yaml_path)

    i = 0
    loaded_models = []
    for model_path in self.models_dir.glob("*.pkl"):
        try:
            logger.info(f"Loading model from {model_path}")
            with open(model_path, "rb") as file_:
                model = pickle.load(file_)
                model.pool = True
                self.models[f"model_{i}"] = model
                loaded_models.append(model_path.name)
                i += 1
        except Exception as e:
            logger.error(f"Error loading model {model_path.name}: {e}")

    return loaded_models
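
Example: a sketch for reloading previously trained committee members, e.g. in a fresh session before evaluation or screening.

Python
loaded = classifier.load_trained_models()
print(loaded)                   # e.g. ['model.gbdt_classifier.rep0.pkl', ...]
print(len(classifier.models))   # one entry per .pkl file found in models_dir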

plot_performance(predictions_dict, include_ensemble=True)

Plot model performance evaluation curves: ROC, Precision, and Recall.

This method generates a multi-panel plot that visualizes the performance of different models on classification tasks. It includes:

- ROC curve with AUC (Area Under the Curve)
- Precision-Threshold curve
- Recall-Threshold curve

The method also supports an optional ensemble model performance evaluation by averaging individual model predictions.

Parameters:

    predictions_dict (dict[str, DataFrame], required): A dictionary where keys are model names and values are DataFrames containing the true_value and prediction columns. Each model's predictions will be plotted.
    include_ensemble (bool, optional): If True, the ensemble model performance will also be plotted, based on averaging the predictions of all models. Defaults to True.

Returns:

    None: The method generates and saves the performance plots as PNG and PDF files.

Notes
  • The method assumes that the predictions_dict contains the model predictions (in the prediction column) and the true labels (in the true_value column).
  • The ROC curve is evaluated using the roc_curve function, while the Precision and Recall curves are generated using precision_recall_curve.
  • The final figure is saved in both PNG and PDF formats in the directory defined by self.figures_dir.
Source code in energy_gnome/models/gbdt/classifier.py
Python
def plot_performance(
    self, predictions_dict: dict[str, pd.DataFrame], include_ensemble: bool = True
):
    """
    Plot model performance evaluation curves: ROC, Precision, and Recall.

    This method generates a multi-panel plot that visualizes the performance of different models on
    classification tasks. It includes:
    - ROC curve with AUC (Area Under the Curve)
    - Precision-Recall curve
    - Recall-Threshold curve

    The method also supports an optional ensemble model performance evaluation by averaging
    individual model predictions.

    Args:
        predictions_dict (dict[str, pd.DataFrame]): A dictionary where keys are model names and values are
                                                    DataFrames containing the `true_value` and `prediction` columns.
                                                    Each model's predictions will be plotted.
        include_ensemble (bool, optional): If `True`, the ensemble model performance will also be plotted, which is
                                        based on averaging the predictions of all models. Defaults to `True`.

    Returns:
        None: The method generates and saves the performance plots as PNG and PDF files.

    Notes:
        - The method assumes that the `predictions_dict` contains the model predictions (in the `prediction` column)
        and the true labels (in the `true_value` column).
        - The ROC curve is evaluated using the `roc_curve` function, while the Precision and Recall curves are
        generated using `precision_recall_curve`.
        - The final figure is saved in both PNG and PDF formats in the directory defined by `self.figures_dir`.
    """

    all_predictions = []
    colors = plt.cm.tab10.colors

    # Create a single figure and a grid layout
    fig = plt.figure(figsize=(7, 2), dpi=150)
    grid = fig.add_gridspec(1, 3, wspace=0.5, hspace=0.07)
    ax = [fig.add_subplot(grid[j]) for j in range(3)]

    for i, (model, data) in enumerate(predictions_dict.items()):
        y_true = data["true_value"].values  # Ensure correct format
        y_predictions = data["prediction"].values

        if include_ensemble:
            all_predictions.append(y_predictions)

        fpr, tpr, _ = roc_curve(y_true, y_predictions)
        precision, recall, thresholds = precision_recall_curve(y_true, y_predictions)

        # Use colors from plt.cm.tab10
        color = colors[i % len(colors)]

        ax[0].plot(fpr, tpr, label=f"{model} (AUC: {auc(fpr, tpr):.2f})", color=color)
        ax[1].plot(thresholds, precision[:-1], label=f"{model}", color=color)
        ax[2].plot(thresholds, recall[:-1], label=f"{model}", color=color)

    # If ensemble is enabled, add an extra curve
    if include_ensemble and all_predictions:
        y_ensemble = np.mean(np.column_stack(all_predictions), axis=1)
        fpr, tpr, _ = roc_curve(y_true, y_ensemble)
        precision, recall, thresholds = precision_recall_curve(y_true, y_ensemble)

        ax[0].plot(fpr, tpr, label="Ensemble", color="black", linestyle="--")
        ax[1].plot(thresholds, precision[:-1], label="Ensemble", color="black", linestyle="--")
        ax[2].plot(thresholds, recall[:-1], label="Ensemble", color="black", linestyle="--")

    # Add diagonal reference line in ROC curve
    ax[0].axline([0.5, 0.5], slope=1, lw=0.85, ls="--", color="k", zorder=2)
    ax[0].set_xlabel("FPR")
    ax[0].set_ylabel("TPR")
    ax[0].set_title("ROC Curve")
    ax[0].set_aspect("equal")

    ax[1].set_xlabel("Threshold")
    ax[1].set_ylabel("Precision")
    ax[1].set_title("Precision Curve")
    ax[1].set_ylim([-0.05, 1.05])

    ax[2].set_xlabel("Threshold")
    ax[2].set_ylabel("Recall")
    ax[2].set_title("Recall Curve")
    ax[2].set_ylim([-0.05, 1.05])

    # Move legend outside the plot
    ax[2].legend(loc="upper left", bbox_to_anchor=(1.05, 1), fontsize=8, borderaxespad=0.0)

    fig.suptitle("Model Performance")
    fig.subplots_adjust(top=0.8, right=0.75)  # Adjust right margin for legend space

    # Save figure
    fig.savefig(self.figures_dir / (self._model_spec + ".png"), dpi=330, bbox_inches="tight")
    fig.savefig(self.figures_dir / (self._model_spec + ".pdf"), dpi=330, bbox_inches="tight")

    plt.show()
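
Example: a sketch that feeds the dictionary returned by evaluate() directly into the plotting helper; test_df is a placeholder featurized test set.

Python
pred_dict = classifier.evaluate(test_df)
classifier.plot_performance(pred_dict, include_ensemble=True)
# The figure is also saved in figures_dir as model.gbdt_classifier.png / .pdf.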

screen(db, save_processed=True)

Screen the database for specialized materials using classifier predictions.

This method performs the following steps:

1. Featurizes the database using featurize_db.
2. Evaluates the featurized data using a committee of classifiers to generate predictions.
3. Combines the predictions with the original database and removes rows with missing values or unqualified materials.
4. Optionally saves the processed (screened) database for future use.

Parameters:

    db (GNoMEDatabase, required): A GNoMEDatabase object containing the raw data to be screened.
    save_processed (bool, optional): Whether to save the screened data to the database. Defaults to True.

Returns:

    DataFrame: A DataFrame containing the original data combined with classifier predictions, excluding materials that have missing or unqualified values for screening.

Notes
  • The method assumes that featurize_db and _evaluate_unknown methods are defined and function correctly.
  • The classifier_mean column in the returned DataFrame reflects the mean classifier prediction, which is used to screen specialized materials.
  • The is_specialized column is dropped from the screened DataFrame.
Source code in energy_gnome/models/gbdt/classifier.py
Python
def screen(self, db: GNoMEDatabase, save_processed: bool = True) -> pd.DataFrame:
    """
    Screen the database for specialized materials using classifier predictions.

    This method performs the following steps:
    1. Featurizes the database using `featurize_db`.
    2. Evaluates the featurized data using a committee of classifiers to generate predictions.
    3. Combines the predictions with the original database and removes rows with missing values
       or unqualified materials.
    4. Optionally saves the processed (screened) database for future use.

    Args:
        db (GNoMEDatabase): A `GNoMEDatabase` object containing the raw data to be screened.
        save_processed (bool, optional): Whether to save the screened data to the database. Defaults to `True`.

    Returns:
        (pd.DataFrame): A DataFrame containing the original data combined with classifier predictions,
                    excluding materials that have missing or unqualified values for screening.

    Notes:
        - The method assumes that `featurize_db` and `_evaluate_unknown` methods are defined and function correctly.
        - The `classifier_mean` column in the returned DataFrame reflects the mean classifier prediction, which is
          used to screen specialized materials.
        - The `is_specialized` column is dropped from the screened DataFrame.
    """
    logger.info("Featurizing the database...")
    df_class = self.featurize_db(db)
    logger.info("Screening the database for specialized materials.")
    predictions = self._evaluate_unknown(df_class)
    gnome_df = db.get_database("raw")
    gnome_screened = pd.concat([gnome_df, predictions.reset_index(drop=True)], axis=1)
    gnome_screened.drop(columns=["is_specialized"], inplace=True)
    gnome_screened = gnome_screened[gnome_screened["classifier_mean"].notna()]

    if save_processed:
        logger.info("Saving the screened database.")
        db.databases["processed"] = gnome_screened.copy()
        db.save_database("processed")

    return gnome_screened
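
Example: a sketch of screening a GNoME database with the trained committee; gnome_db is a placeholder GNoMEDatabase instance and the 0.5 cutoff is an arbitrary illustration, not a library default.

Python
screened = classifier.screen(gnome_db, save_processed=True)

# Keep candidates the committee, on average, classifies as specialized.
candidates = screened[screened["classifier_mean"] > 0.5]
print(len(candidates), "candidate materials")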

set_model_settings(yaml_file=None, **kargs)

Set model settings either from a YAML file or provided keyword arguments.

This method allows setting model settings from multiple sources:

1. If a yaml_file is provided, it loads the settings from that file.
2. If additional settings are provided as keyword arguments (kargs), they overwrite the default or loaded settings.

Parameters:

    yaml_file (Path | str, optional): Path to the YAML file containing the model settings. Defaults to None.
    kargs (dict, optional): Dictionary of model settings to override the default ones. Defaults to {}.
Source code in energy_gnome/models/gbdt/classifier.py
Python
def set_model_settings(self, yaml_file: Path | str | None = None, **kargs):
    """
    Set model settings either from a YAML file or provided keyword arguments.

    This method allows setting model settings from multiple sources:
    1. If a `yaml_file` is provided, it loads the settings from that file.
    2. If additional settings are provided as keyword arguments (`kargs`), they overwrite
    the default or loaded settings.

    Args:
        yaml_file (Path, str, optional): Path to the YAML file containing the model settings.
        kargs (dict, optional): Dictionary of model settings to override the default ones.

    """
    # Accessing model settings (YAML FILE)
    if yaml_file:
        self._load_model_setting(yaml_file)

    # Accessing model settings
    for att, defvalue in DEFAULT_GBDT_SETTINGS.items():
        if att in kargs:
            setattr(self, att, kargs[att])
        else:
            try:
                att_exist = getattr(self, att)
                att_exist = att_exist == att_exist
            except AttributeError:
                att_exist = False
            if not att_exist:
                setattr(self, att, defvalue)
                logger.warning(f"using default value {defvalue} for {att} setting")

    if yaml_file is None or os.path.dirname(str(yaml_file)) != str(self.models_dir):
        self._save_model_settings()
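
Example: a sketch of both ways to set the model settings; whether n_committers is one of the keys in DEFAULT_GBDT_SETTINGS is an assumption, and the YAML path is illustrative.

Python
# Override selected settings with keyword arguments.
classifier.set_model_settings(n_committers=5)   # assumes this key exists in DEFAULT_GBDT_SETTINGS

# Or load settings from a YAML file.
classifier.set_model_settings(yaml_file="models/model.gbdt_classifier.yaml")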