Dataset Module

energy_gnome.dataset

energy_gnome.dataset.BaseDatabase

Bases: ABC

Abstract base class for managing a structured database system with multiple processing stages and data subsets.

This class provides a standardized framework for handling data across different stages of processing (raw, processed, final). It ensures proper directory structure, initializes database placeholders, and offers an interface for subclassing specialized database implementations.
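
Subclasses must implement the abstract methods _set_is_specialized, retrieve_materials, and save_cif_files. A minimal sketch of a concrete subclass is shown below; the class name, import path, and method bodies are illustrative assumptions, not part of the library.

Python
from pathlib import Path

from energy_gnome.dataset import BaseDatabase  # import path assumed from this page


class PerovskiteDatabase(BaseDatabase):  # hypothetical subclass name
    """Illustrative specialization of BaseDatabase."""

    def _set_is_specialized(self):
        # Mark this database as holding materials for a specialized energy application.
        self._is_specialized = True

    def retrieve_materials(self):
        # A real implementation would query an external source and return a list
        # of material entries; this stub keeps the sketch minimal.
        return []

    def save_cif_files(self) -> None:
        # A real implementation would write CIF structure files into the raw stage
        # directory so they can later be copied with copy_cif_files().
        pass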

Attributes:

name (str): The name of the database instance.
data_dir (Path): Root directory where database files are stored.
processing_stages (list[str]): The main stages of data processing.
interim_sets (list[str]): Subsets within the training pipelines (e.g., training, validation).
database_directories (dict[str, Path]): Mapping of processing stages to their respective directories.
database_paths (dict[str, Path]): Paths to database files for each processing stage.
databases (dict[str, DataFrame]): Data storage for each processing stage.
subset (dict[str, DataFrame]): Storage for subsets like training, validation, and testing.
_update_raw (bool): Flag indicating whether raw data should be updated.
_is_specialized (bool): Indicates whether a subclass contains materials for specialized energy applications.

Source code in energy_gnome/dataset/base_dataset.py
Python
class BaseDatabase(ABC):
    """
    Abstract base class for managing a structured database system with multiple
    processing stages and data subsets.

    This class provides a standardized framework for handling data across different
    stages of processing (`raw`, `processed`, `final`). It ensures proper directory
    structure, initializes database placeholders, and offers an interface for
    subclassing specialized database implementations.

    Attributes:
        name (str): The name of the database instance.
        data_dir (Path): Root directory where database files are stored.
        processing_stages (list[str]): The main stages of data processing.
        interim_sets (list[str]): Subsets within the training pipelines (e.g., training, validation).
        database_directories (dict[str, Path]): Mapping of processing stages to their respective directories.
        database_paths (dict[str, Path]): Paths to database files for each processing stage.
        databases (dict[str, pd.DataFrame]): Data storage for each processing stage.
        subset (dict[str, pd.DataFrame]): Storage for subsets like training, validation, and testing.
        _update_raw (bool): Flag indicating whether raw data should be updated.
        _is_specialized (bool): Indicates whether a subclass contains materials for specialized energy applications.
    """

    def __init__(self, name: str, data_dir: Path = DATA_DIR):
        """
        Initializes the BaseDatabase instance.

        Sets up the directory structure for storing data across different processing
        stages (`raw`, `processed`, `final`) and initializes empty Pandas DataFrames
        for managing the data.

        Args:
            name (str): The name of the database instance.
            data_dir (Path, optional): Root directory path for storing database files.
                Defaults to `DATA_DIR`.
        """
        self.name = name

        if isinstance(data_dir, str):
            self.data_dir = Path(data_dir)
        else:
            self.data_dir = data_dir
        self.data_dir.mkdir(parents=True, exist_ok=True)

        # Define processing stages
        self.processing_stages = ["raw", "processed", "final"]
        self.interim_sets = ["training", "validation", "testing"]

        # Initialize directories, paths, and databases for each stage
        self.database_directories = {
            stage: self.data_dir / stage / self.name for stage in self.processing_stages
        }
        for stage_dir in self.database_directories.values():
            stage_dir.mkdir(parents=True, exist_ok=True)

        self.database_paths = {
            stage: dir_path / "database.json"
            for stage, dir_path in self.database_directories.items()
        }

        self.databases = {stage: pd.DataFrame() for stage in self.processing_stages}
        self.subset = {subset: pd.DataFrame() for subset in self.interim_sets}
        self._update_raw = False
        self._is_specialized = False
        self._set_is_specialized()
        self.load_all()

    @abstractmethod
    def _set_is_specialized(
        self,
    ):
        """
        Set the `is_specialized` attribute.

        This method marks the database as specialized by setting the `is_specialized`
        attribute to `True`. It is typically used to indicate that the database
        is intended for a specific class of data, corresponding to specialized
        energy materials.

        Returns:
            None
        """
        pass

    def allow_raw_update(self):
        """
        Enables modifications to the `raw` data stage.

        This method sets the internal flag `_update_raw` to `True`, allowing changes
        to be made to the raw data without raising an exception.

        Warning:
            Use with caution, as modifying raw data can impact data integrity.
        """
        self._update_raw = True

    def compare_databases(self, new_db: pd.DataFrame, stage: str) -> pd.DataFrame:
        """
        Compare two databases and identify new entry IDs.

        This method compares an existing database (loaded from the specified stage) with a new database.
        It returns the entries from the new database that are not present in the existing one.

        Args:
            new_db (pd.DataFrame): New database to compare.
            stage (str): Processing stage ("raw", "processed", "final").

        Returns:
            pd.DataFrame: Subset of `new_db` containing only new entry IDs.

        Logs:
            - DEBUG: The number of new entries found.
            - WARNING: If the old database is empty and nothing can be compared.
        """
        old_db = self.get_database(stage=stage)
        if not old_db.empty:
            new_ids_set = set(new_db["material_id"])
            old_ids_set = set(old_db["material_id"])
            new_ids_only = new_ids_set - old_ids_set
            logger.debug(f"Found {len(new_ids_only)} new IDs in the new database.")
            return new_db[new_db["material_id"].isin(new_ids_only)]
        else:
            logger.warning("Nothing to compare here...")
            return new_db

    def backup_and_changelog(
        self,
        old_db: pd.DataFrame,
        new_db: pd.DataFrame,
        differences: pd.Series,
        stage: str,
    ) -> None:
        """
        Backup the old database and update the changelog with identified differences.

        This method saves a backup of the existing database before updating it with new data.
        It also logs the changes detected by comparing the old and new databases, storing the
        details in a changelog file.

        Args:
            old_db (pd.DataFrame): The existing database before updating.
            new_db (pd.DataFrame): The new database with updated entries.
            differences (pd.Series): A series containing the material IDs of entries that differ.
            stage (str): The processing stage ("raw", "processed", "final") for which the backup
                        and changelog are being maintained.

        Raises:
            ValueError: If an invalid `stage` is provided.
            OSError: If there is an issue writing to the backup or changelog files.

        Logs:
            - ERROR: If an invalid stage is provided.
            - DEBUG: When the old database is successfully backed up.
            - ERROR: If the backup process fails.
            - DEBUG: When the changelog is successfully updated with differences.
            - ERROR: If updating the changelog fails.
        """
        if stage not in self.processing_stages:
            logger.error(f"Invalid stage: {stage}. Must be one of {self.processing_stages}.")
            raise ValueError(f"stage must be one of {self.processing_stages}.")

        # Backup the old database
        backup_path = self.database_directories[stage] / "old_database.json"
        try:
            old_db.to_json(backup_path)
            logger.debug(f"Old database backed up to {backup_path}")
        except Exception as e:
            logger.error(f"Failed to backup old database to {backup_path}: {e}")
            raise OSError(f"Failed to backup old database to {backup_path}: {e}") from e

        # Prepare changelog
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        changelog_path = self.database_directories[stage] / "changelog.txt"

        header = (
            f"= Change Log - {timestamp} ".ljust(70, "=") + "\n"
            "Difference old_database.json VS database.json\n"
            f"{'ID':<15}{'Formula':<30}{'Last Updated':<25}\n" + "-" * 70 + "\n"
        )

        # Set index for faster lookup
        new_db_indexed = new_db.set_index("material_id")

        # Process differences efficiently
        changes = [
            f"{identifier:<15}{new_db_indexed.at[identifier, 'formula_pretty'] if identifier in new_db_indexed.index else 'N/A':<30}"
            f"{new_db_indexed.at[identifier, 'last_updated'] if identifier in new_db_indexed.index else 'N/A':<25}\n"
            for identifier in differences["material_id"]
        ]

        try:
            with open(changelog_path, "a") as file:
                file.write(header + "".join(changes))
            logger.debug(f"Changelog updated at {changelog_path} with {len(differences)} changes.")
        except Exception as e:
            logger.error(f"Failed to update changelog at {changelog_path}: {e}")
            raise OSError(f"Failed to update changelog at {changelog_path}: {e}") from e

    def compare_and_update(self, new_db: pd.DataFrame, stage: str) -> pd.DataFrame:
        """
        Compare and update the database with new entries.

        This method checks for new entries in the provided database and updates the stored
        database accordingly. It ensures that raw data remains immutable unless explicitly
        allowed. If new entries are found, the old database is backed up, and a changelog
        is created.

        Args:
            new_db (pd.DataFrame): The new database to compare against the existing one.
            stage (str): The processing stage ("raw", "processed", "final").

        Returns:
            pd.DataFrame: The updated database containing new entries.

        Raises:
            ImmutableRawDataError: If attempting to modify raw data without explicit permission.

        Logs:
            - WARNING: If new items are detected in the database.
            - ERROR: If an attempt is made to modify immutable raw data.
            - INFO: When saving or updating the database.
            - INFO: If no new items are found and no update is required.
        """
        old_db = self.get_database(stage=stage)
        db_diff = self.compare_databases(new_db, stage)
        if not db_diff.empty:
            logger.warning(f"The new database contains {len(db_diff)} new items.")

            if stage == "raw" and not self._update_raw:
                logger.error("Raw data must be treated as immutable!")
                logger.error(
                    "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
                )
                raise ImmutableRawDataError(
                    "Raw data must be treated as immutable!\n"
                    "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
                )
            else:
                if stage == "raw":
                    logger.info(
                        "Be careful you are changing the raw data which must be treated as immutable!"
                    )
                if old_db.empty:
                    logger.info(f"Saving new {stage} data in {self.database_paths[stage]}.")
                else:
                    logger.info(
                        f"Updating the {stage} data and saving it in {self.database_paths[stage]}."
                    )
                    self.backup_and_changelog(
                        old_db,
                        new_db,
                        db_diff,
                        stage,
                    )
                self.databases[stage] = new_db
                self.save_database(stage)
        else:
            logger.info("No new items found. No update required.")

    @abstractmethod
    def retrieve_materials(self) -> list[Any]:
        pass

    @abstractmethod
    def save_cif_files(self) -> None:
        pass

    def copy_cif_files(
        self,
        stage: str,
        mute_progress_bars: bool = True,
    ) -> None:
        """
        Copy CIF files from the raw stage to another processing stage.

        This method transfers CIF files from the `raw` stage directory to the specified
        processing stage (`processed` or `final`). It ensures that existing CIF files in
        the target directory are cleared before copying and updates the database with
        the new file paths.

        Args:
            stage (str): The processing stage to copy CIF files to (`processed`, `final`).
            mute_progress_bars (bool, optional): If True, disables progress bars. Defaults to True.

        Raises:
            ValueError: If the stage argument is `raw`, as copying is only allowed from `raw` to other stages.
            MissingData: If the raw CIF directory does not exist or is empty.

        Logs:
            - WARNING: If the target directory is cleaned or CIF files are missing for some materials.
            - ERROR: If a CIF file fails to copy.
            - INFO: When CIF files are successfully copied and the database is updated.
        """
        if stage == "raw":
            logger.error(
                "Stage argument cannot be 'raw'. You can only copy from 'raw' to other stages."
            )
            raise ValueError("Stage argument cannot be 'raw'.")

        source_dir = self.database_directories["raw"] / "structures"
        saving_dir = self.database_directories[stage] / "structures"

        # Clean the target directory if it exists
        if saving_dir.exists():
            logger.warning(f"Cleaning the content in {saving_dir}")
            sh.rmtree(saving_dir)

        # Check if source directory exists and is not empty
        cif_files = {
            file.stem for file in source_dir.glob("*.cif")
        }  # Set of existing CIF filenames
        if not cif_files:
            logger.warning(
                f"The raw CIF directory does not exist or is empty. Check: {source_dir}"
            )
            raise MissingData(
                f"The raw CIF directory does not exist or is empty. Check: {source_dir}"
            )

        # Create the target directory
        saving_dir.mkdir(parents=True, exist_ok=False)

        # Create an index mapping for fast row updates
        db_stage = self.databases[stage].set_index("material_id")
        db_stage["cif_path"] = pd.NA  # Initialize empty column

        missing_ids = []
        for material_id in tqdm(
            self.databases[stage]["material_id"],
            desc=f"Copying materials ('raw' -> '{stage}')",
            disable=mute_progress_bars,
        ):
            if material_id not in cif_files:
                missing_ids.append(material_id)
                continue  # Skip missing files

            source_cif_path = source_dir / f"{material_id}.cif"
            cif_path = saving_dir / f"{material_id}.cif"

            try:
                sh.copy2(source_cif_path, cif_path)
                db_stage.at[material_id, "cif_path"] = str(cif_path)  # Direct assignment

            except Exception as e:
                logger.error(f"Failed to copy CIF for Material ID {material_id}: {e}")
                continue  # Skip to next material instead of stopping execution

        # Restore the updated database index
        self.databases[stage] = db_stage.reset_index()

        # Log missing files once
        if missing_ids:
            logger.warning(f"Missing CIF files for {len(missing_ids)} material IDs.")

        # Save the updated database
        self.save_database(stage)
        logger.info(f"CIF files copied to stage '{stage}' and database updated successfully.")

    def get_database(self, stage: str, subset: str | None = None) -> pd.DataFrame:
        """
        Retrieves the database for a specified processing stage or subset.

        This method returns the database associated with the given `stage`. If the stage
        is `raw`, `processed`, or `final`, it retrieves the corresponding database.
        If `stage` is `interim`, a specific `subset` (e.g., `training`, `validation`, `testing`)
        must be provided.

        Args:
            stage (str): The processing stage to retrieve. Must be one of
                `raw`, `processed`, `final`, or `interim`.
            subset (Optional[str]): The subset to retrieve when `stage` is `interim`.
                Must be one of `training`, `validation`, or `testing`.

        Returns:
            pd.DataFrame: The requested database or subset.

        Raises:
            ValueError: If an invalid `stage` is provided.
            ValueError: If `stage` is `interim` but an invalid `subset` is specified.

        Logs:
            - ERROR: If an invalid `stage` is provided.
            - WARNING: If the retrieved database is empty.
        """
        if stage not in self.processing_stages + ["interim"]:
            logger.error(
                f"Invalid stage: {stage}. Must be one of {self.processing_stages + ['interim']}."
            )
            raise ValueError(f"stage must be one of {self.processing_stages + ['interim']}.")
        if stage in self.processing_stages:
            out_db = self.databases[stage]
        elif stage in ["interim"] and subset is not None:
            out_db = self.subset[subset]
        else:
            raise ValueError(f"subset must be one of {self.interim_sets}.")

        if len(out_db) == 0:
            logger.warning("Empty database found.")
        return out_db

    def load_database(self, stage: str) -> None:
        """
        Loads the existing database for a specified processing stage.

        This method retrieves the database stored in a JSON file for the given
        processing stage (`raw`, `processed`, or `final`). If the file exists,
        it loads the data into a pandas DataFrame. If the file is missing,
        a warning is logged, and an empty DataFrame remains in place.

        Args:
            stage (str): The processing stage to load. Must be one of
                `raw`, `processed`, or `final`.

        Raises:
            ValueError: If `stage` is not one of the predefined processing stages.

        Logs:
            ERROR: If an invalid `stage` is provided.
            DEBUG: If a database is successfully loaded.
            WARNING: If the database file is not found.
        """
        if stage not in self.processing_stages:
            logger.error(f"Invalid stage: {stage}. Must be one of {self.processing_stages}.")
            raise ValueError(f"stage must be one of {self.processing_stages}.")

        db_path = self.database_paths[stage]
        if db_path.exists():
            self.databases[stage] = pd.read_json(db_path)
            if stage == "raw":
                self.databases[stage]["is_specialized"] = self._is_specialized
            logger.debug(f"Loaded existing database from {db_path}.")
        else:
            logger.warning(f"Not found at {db_path}.")

    def _load_interim(
        self, subset: str = "training", model_type: str = "regressor"
    ) -> pd.DataFrame:
        """
        Load the existing interim databases.

        This method attempts to load an interim database corresponding to the specified
        subset and model type. If the database file is found, it is loaded into a pandas
        DataFrame. If not found, a warning is logged, and an empty DataFrame is returned.

        Args:
            subset (str): The subset of the interim dataset to load (`training`, `validation`, `testing`).
            model_type (str, optional): The type of model associated with the data (`regressor`, `classifier`).
                Defaults to "regressor".

        Returns:
            (pd.DataFrame): The loaded database if found, otherwise an empty DataFrame.

        Raises:
            ValueError: If the provided `subset` is not one of the allowed interim sets.

        Logs:
            - ERROR: If an invalid subset is provided.
            - DEBUG: If an existing database is successfully loaded.
            - WARNING: If no database file is found.
        """

        if subset not in self.interim_sets:
            logger.error(f"Invalid set: {subset}. Must be one of {self.interim_sets}.")
            raise ValueError(f"set must be one of {self.interim_sets}.")

        db_name = subset + "_db.json"
        db_path = self.data_dir / "interim" / self.name / model_type / db_name
        if db_path.exists():
            self.subset[subset] = pd.read_json(db_path)
            logger.debug(f"Loaded existing database from {db_path}")
        else:
            logger.warning(f"No existing database found at {db_path}")
        return self.subset[subset]

    def load_regressor_data(self, subset: str = "training"):
        """
        Load the interim dataset for a regression model.

        This method retrieves the specified subset of the interim dataset specifically for
        regression models by internally calling `_load_interim`.

        Args:
            subset (str, optional): The subset of the dataset to load (`training`, `validation`, `testing`).
                Defaults to "training".

        Returns:
            (pd.DataFrame): The loaded regression dataset or an empty DataFrame if not found.

        Raises:
            ValueError: If the provided `subset` is not one of the allowed interim sets.

        Logs:
            - ERROR: If an invalid subset is provided.
            - DEBUG: If an existing database is successfully loaded.
            - WARNING: If no database file is found.
        """
        return self._load_interim(subset=subset, model_type="regressor")

    def load_classifier_data(self, subset: str = "training"):
        """
        Load the interim dataset for a classification model.

        This method retrieves the specified subset of the interim dataset specifically for
        classification models by internally calling `_load_interim`.

        Args:
            subset (str, optional): The subset of the dataset to load (`training`, `testing`).
                Defaults to "training".

        Returns:
            (pd.DataFrame): The loaded classification dataset or an empty DataFrame if not found.

        Raises:
            ValueError: If the provided `subset` is not one of the allowed interim sets.

        Logs:
            - ERROR: If an invalid subset is provided.
            - DEBUG: If an existing database is successfully loaded.
            - WARNING: If no database file is found.
        """
        return self._load_interim(subset=subset, model_type="classifier")

    def load_all(self):
        """
        Loads the databases for all processing stages and subsets.

        This method sequentially loads the databases for all predefined processing
        stages (`raw`, `processed`, `final`). Additionally, it loads both regressor
        and classifier data for all interim subsets (`training`, `validation`, `testing`).

        Calls:
            - `load_database(stage)`: Loads the database for each processing stage.
            - `load_regressor_data(subset)`: Loads regressor-specific data for each subset.
            - `load_classifier_data(subset)`: Loads classifier-specific data for each subset.
        """
        for stage in self.processing_stages:
            self.load_database(stage)
        for subset in self.interim_sets:
            self.load_regressor_data(subset)
            self.load_classifier_data(subset)

    def save_database(self, stage: str) -> None:
        """
        Saves the current state of the database to a JSON file.

        This method serializes the database DataFrame for the specified processing
        stage (`raw`, `processed`, or `final`) and writes it to a JSON file. If an
        existing file is present, it is removed before saving the new version.

        Args:
            stage (str): The processing stage to save. Must be one of
                `raw`, `processed`, or `final`.

        Raises:
            ValueError: If `stage` is not one of the predefined processing stages.
            OSError: If an error occurs while writing the DataFrame to the file.

        Logs:
            - ERROR: If an invalid `stage` is provided.
            - INFO: If the database is successfully saved.
            - ERROR: If the save operation fails.
        """
        if stage not in self.processing_stages:
            logger.error(f"Invalid stage: {stage}. Must be one of {self.processing_stages}.")
            raise ValueError(f"stage must be one of {self.processing_stages}.")

        db_path = self.database_paths[stage]
        if os.path.exists(db_path):
            os.unlink(db_path)
        try:
            self.databases[stage].to_json(db_path)
            logger.info(f"Database saved to {db_path}")
        except Exception as e:
            logger.error(f"Failed to save database to {db_path}: {e}")
            raise OSError(f"Failed to save database to {db_path}: {e}") from e

    def build_reduced_database(
        self,
        size: int,
        new_name: str,
        stage: str,
        seed: int = 42,
    ) -> pd.DataFrame:
        """
        Build a reduced database by sampling random entries from an existing database.

        This method creates a new database by randomly sampling a specified number of entries
        from the given stage (`raw`, `processed`, or `final`) of the existing database. The
        new database is saved and returned as a new instance.

        Args:
            size (int): The number of entries for the reduced database.
            new_name (str): The name for the new reduced database.
            stage (str): The processing stage (`raw`, `processed`, or `final`) to sample from.
            seed (int, optional): The random seed used to generate reproducible samples. Defaults to 42.

        Returns:
            (pd.DataFrame): The reduced database as a pandas DataFrame with the sampled entries.

        Raises:
            ValueError: If `size` is 0, indicating that an empty database is being created.

        Logs:
            ERROR: If the database size is set to 0.
            INFO: If the new reduced database is successfully created and saved.
        """
        new_database = self.__class__(data_dir=self.data_dir, name=new_name)
        if size == 0:
            logger.error("You are creating an empty database.")
            raise ValueError("You are creating an empty database.")
        # copy database
        for stages, db in self.databases.items():
            if len(self.databases[stages]) != 0:
                new_database.databases[stages] = db.copy()
                make_link(self.database_paths[stages], new_database.database_paths[stages])
        for subset, db in self.subset.items():
            if len(self.subset[subset]) != 0:
                new_database.subset[subset] = db.copy()
                subset_path_father = self.data_dir / "interim" / self.name / (subset + "_db.json")
                subset_path_son = (
                    new_database.data_dir / "interim" / new_database.name / (subset + "_db.json")
                )
                if not subset_path_son.parent.exists():
                    subset_path_son.parent.mkdir(parents=True, exist_ok=True)
                make_link(subset_path_father, subset_path_son)

        database = new_database.databases[stage].copy()
        n_all = len(database)
        rng = Generator(PCG64(seed))
        row_i_all = rng.choice(n_all, size, replace=False)
        new_database.databases[stage] = database.iloc[row_i_all, :].reset_index(drop=True)

        new_database.save_database(stage)

        return new_database

    def save_split_db(self, database_dict: dict, model_type: str = "regressor") -> None:
        """
        Saves the split databases (training, validation, testing) into JSON files.

        This method saves the split databases into individual files in the designated
        directory for the given model type (e.g., `regressor`, `classifier`). It checks
        whether the databases are empty before saving. If any of the databases are empty,
        it logs a warning or raises an error depending on the subset (training, validation, testing).

        Args:
            database_dict (dict): A dictionary containing the split databases (`train`,
                `valid`, `test`) as pandas DataFrames.
            model_type (str, optional): The model type for which the splits are being saved
                (e.g., `"regressor"`, `"classifier"`). Defaults to `"regressor"`.

        Returns:
            None

        Raises:
            DatabaseError: If the training dataset is empty
                when attempting to save.

        Logs:
            INFO: When a database is successfully saved to its designated path.
            WARNING: When the validation or testing database is empty.
            ERROR: If any dataset is empty and it is the `train` subset.
        """
        db_path = self.data_dir / "interim" / self.name / model_type
        # if not db_path.exists():
        db_path.mkdir(parents=True, exist_ok=True)

        split_paths = dict(
            train=db_path / "training_db.json",
            valid=db_path / "validation_db.json",
            test=db_path / "testing_db.json",
        )

        for db_name, db_path in split_paths.items():
            if database_dict[db_name].empty:
                if db_name == "valid":
                    logger.warning("No validation database provided.")
                    continue
                elif db_name == "test":
                    logger.warning("No testing database provided.")
                    continue
                else:
                    raise DatabaseError(f"The {db_name} is empty, check the splitting.")
            else:
                database_dict[db_name].to_json(db_path)
                logger.info(f"{db_name} database saved to {db_path}")

    def split_regressor(
        self,
        target_property: str,
        valid_size: float = 0.2,
        test_size: float = 0.05,
        seed: int = 42,
        balance_composition: bool = True,
        save_split: bool = False,
    ) -> None:
        """
        Splits the processed database into training, validation, and test sets for regression tasks.

        This method divides the database into three subsets: training, validation, and test. It
        either performs a random split with or without balancing the frequency of chemical
        species across the splits. If `balance_composition` is True, it ensures that
        elements appear in approximately equal proportions in each subset. The split sizes for
        validation and test sets can be customized.

        Args:
            target_property (str): The property used for the regression task (e.g., a material
                property like "energy").
            valid_size (float, optional): The proportion of the data to use for the validation set.
                Defaults to 0.2.
            test_size (float, optional): The proportion of the data to use for the test set.
                Defaults to 0.05.
            seed (int, optional): The random seed for reproducibility. Defaults to 42.
            balance_composition (bool, optional): Whether to balance the frequency of chemical species
                across the subsets. Defaults to True.
            save_split (bool, optional): Whether to save the resulting splits as files. Defaults to False.

        Returns:
            None

        Raises:
            ValueError: If the sum of `valid_size` and `test_size` exceeds 1.

        Logs:
            INFO: If the dataset is successfully split.
            ERROR: If the sum of `valid_size` and `test_size` is greater than 1.
        """
        if balance_composition:
            db_dict = random_split(
                self.get_database("processed"),
                target_property,
                valid_size=valid_size,
                test_size=test_size,
                seed=seed,
            )
        else:
            dev_size = valid_size + test_size
            if abs(dev_size - test_size) < 1e-8:
                train_, test_ = train_test_split(
                    self.get_database("processed"), test_size=test_size, random_state=seed
                )
                valid_ = pd.DataFrame()
            elif abs(dev_size - valid_size) < 1e-8:
                train_, valid_ = train_test_split(
                    self.get_database("processed"), test_size=valid_size, random_state=seed
                )
                test_ = pd.DataFrame()
            else:
                train_, dev_ = train_test_split(
                    self.get_database("processed"),
                    test_size=valid_size + test_size,
                    random_state=seed,
                )
                valid_, test_ = train_test_split(
                    dev_, test_size=test_size / (valid_size + test_size), random_state=seed
                )

            db_dict = {"train": train_, "valid": valid_, "test": test_}

        if save_split:
            self.save_split_db(db_dict, "regressor")

        self.subset = db_dict

    def split_classifier(
        self,
        test_size: float = 0.2,
        seed: int = 42,
        balance_composition: bool = False,
        save_split: bool = False,
    ) -> None:
        """
        Splits the processed database into training and test sets for classification tasks.

        This method divides the database into two subsets: training and test. It always stratifies
        the split based on the target property (`is_specialized`). If `balance_composition` is True,
        it additionally balances the frequency of chemical species across the training and test sets.
        The size of the test set can be customized with the `test_size` argument.

        Args:
            test_size (float, optional): The proportion of the data to use for the test set.
                Defaults to 0.2.
            seed (int, optional): The random seed for reproducibility. Defaults to 42.
            balance_composition (bool, optional): Whether to balance the frequency of chemical
                species across the training and test sets in addition to stratifying by the
                target property (`is_specialized`). Defaults to False.
            save_split (bool, optional): Whether to save the resulting splits as files. Defaults to False.

        Returns:
            None

        Raises:
            ValueError: If the database is empty or invalid.

        Logs:
            INFO: If the dataset is successfully split.
            ERROR: If the dataset is invalid or empty.
        """
        target_property = "is_specialized"
        if balance_composition:
            db_dict = random_split(
                self.get_database("processed"),
                target_property,
                valid_size=0,
                test_size=test_size,
                seed=seed,
            )
        else:
            train_, test_ = train_test_split(
                self.get_database("processed"), test_size=test_size, random_state=seed
            )
            db_dict = {"train": train_, "test": test_}

        if save_split:
            self.save_split_db(db_dict, "classifier")

        self.subset = db_dict

    def __repr__(self) -> str:
        """
        Text representation of the BaseDatabase instance.
        Used for `print()` and `str()` calls.

        Returns:
            str: ASCII table representation of the database
        """
        # Gather information about each stage
        data = {
            "Stage": [],
            "Entries": [],
            "Last Modified": [],
            "Size": [],
            "Storage Path": [],
        }

        # Calculate column widths
        widths = [10, 8, 17, 10, 55]

        for stage in self.processing_stages:
            # Get database info
            db = self.databases[stage]
            path = self.database_paths[stage]

            # Get file modification time and size if file exists
            if path.exists():
                modified = path.stat().st_mtime
                modified_time = pd.Timestamp.fromtimestamp(modified).strftime("%Y-%m-%d %H:%M")
                size = path.stat().st_size / 1024  # Convert to KB
                size_str = f"{size:.1f} KB" if size < 1024 else f"{size / 1024:.1f} MB"
            else:
                modified_time = "Not created"
                size_str = "0 KB"

            path_str = str(path.resolve())
            if len(path_str) > widths[4]:
                path_str = ".." + path_str[len(path_str) - widths[4] + 3 :]

            # Append data
            data["Stage"].append(stage.capitalize())
            data["Entries"].append(len(db))
            data["Last Modified"].append(modified_time)
            data["Size"].append(size_str)
            data["Storage Path"].append(path_str)

        # Create DataFrame
        info_df = pd.DataFrame(data)

        # Text representation for terminal/print
        def create_separator(widths):
            return "+" + "+".join("-" * (w + 1) for w in widths) + "+"

        # Create the text representation
        lines = []

        # Add title
        title = f" {self.__class__.__name__} Summary "
        lines.append(f"\n{title:=^{sum(widths) + len(widths) * 2 + 1}}")

        # Add header
        separator = create_separator(widths)
        lines.append(separator)

        header = (
            "|" + "|".join(f" {col:<{widths[i]}}" for i, col in enumerate(info_df.columns)) + "|"
        )
        lines.append(header)
        lines.append(separator)

        # Add data rows
        for _, row in info_df.iterrows():
            line = "|" + "|".join(f" {str(val):<{widths[i]}}" for i, val in enumerate(row)) + "|"
            lines.append(line)

        # Add bottom separator
        lines.append(separator)

        return "\n".join(lines)

    def _repr_html_(self) -> str:
        """
        HTML representation of the BaseDatabase instance.
        Used for Jupyter notebook display.

        Returns:
            str: HTML representation of the database
        """
        # Gather information about each stage
        data = {
            "Stage": [],
            "Entries": [],
            "Last Modified": [],
            "Size": [],
            "Storage Path": [],
        }

        for stage in self.processing_stages:
            # Get database info
            db = self.databases[stage]
            path = self.database_paths[stage]

            # Get file modification time and size if file exists
            if path.exists():
                modified = path.stat().st_mtime
                modified_time = pd.Timestamp.fromtimestamp(modified).strftime("%Y-%m-%d %H:%M")
                size = path.stat().st_size / 1024  # Convert to KB
                size_str = f"{size:.1f} KB" if size < 1024 else f"{size / 1024:.1f} MB"
            else:
                modified_time = "Not created"
                size_str = "0 KB"

            # Append data
            data["Stage"].append(stage.capitalize())
            data["Entries"].append(len(db))
            data["Last Modified"].append(modified_time)
            data["Size"].append(size_str)
            data["Storage Path"].append(str(path.resolve()))

        # Create DataFrame
        info_df = pd.DataFrame(data)

        # Generate header row
        header_cells = " ".join(
            f'<th style="padding: 12px 15px; text-align: left;">{col}</th>'
            for col in info_df.columns
        )

        # Generate table rows
        table_rows = ""
        for _, row in info_df.iterrows():
            cells = "".join(f'<td style="padding: 12px 15px;">{val}</td>' for val in row)
            table_rows += f"<tr style='border-bottom: 1px solid #e9ecef;'>{cells}</tr>"

        # Create the complete HTML
        html = (
            """<style>
                @media (prefers-color-scheme: dark) {
                    .database-container { background-color: #1e1e1e !important; }
                    .database-title { color: #e0e0e0 !important; }
                    .database-table { background-color: #2d2d2d !important; }
                    .database-header { background-color: #4a4a4a !important; }
                    .database-cell { border-color: #404040 !important; }
                    .database-info { color: #b0b0b0 !important; }
                }
            </style>"""
            '<div style="font-family: Arial, sans-serif; padding: 20px; background:transparent; '
            'border-radius: 8px;">'
            f'<h3 style="color: #58bac7; margin-bottom: 15px;">{self.__class__.__name__}</h3>'
            '<div style="overflow-x: auto;">'
            '<table class="database-table" style="border-collapse: collapse; width: 100%;'
            ' box-shadow: 0 1px 3px rgba(0,0,0,0.1); background:transparent;">'
            # '<table style="border-collapse: collapse; width: 100%; background-color: white; '
            # 'box-shadow: 0 1px 3px rgba(0,0,0,0.1);">'
            "<thead>"
            f'<tr style="background-color: #58bac7; color: white;">{header_cells}</tr>'
            "</thead>"
            f"<tbody>{table_rows}</tbody>"
            "</table>"
            "</div>"
            "</div>"
        )
        return html

__init__(name, data_dir=DATA_DIR)

Sets up the directory structure for storing data across different processing stages (raw, processed, final) and initializes empty Pandas DataFrames for managing the data.

Parameters:

name (str, required): The name of the database instance.
data_dir (Path, optional): Root directory path for storing database files. Defaults to DATA_DIR.
Source code in energy_gnome/dataset/base_dataset.py
Python
def __init__(self, name: str, data_dir: Path = DATA_DIR):
    """
    Initializes the BaseDatabase instance.

    Sets up the directory structure for storing data across different processing
    stages (`raw`, `processed`, `final`) and initializes empty Pandas DataFrames
    for managing the data.

    Args:
        name (str): The name of the database instance.
        data_dir (Path, optional): Root directory path for storing database files.
            Defaults to `DATA_DIR`.
    """
    self.name = name

    if isinstance(data_dir, str):
        self.data_dir = Path(data_dir)
    else:
        self.data_dir = data_dir
    self.data_dir.mkdir(parents=True, exist_ok=True)

    # Define processing stages
    self.processing_stages = ["raw", "processed", "final"]
    self.interim_sets = ["training", "validation", "testing"]

    # Initialize directories, paths, and databases for each stage
    self.database_directories = {
        stage: self.data_dir / stage / self.name for stage in self.processing_stages
    }
    for stage_dir in self.database_directories.values():
        stage_dir.mkdir(parents=True, exist_ok=True)

    self.database_paths = {
        stage: dir_path / "database.json"
        for stage, dir_path in self.database_directories.items()
    }

    self.databases = {stage: pd.DataFrame() for stage in self.processing_stages}
    self.subset = {subset: pd.DataFrame() for subset in self.interim_sets}
    self._update_raw = False
    self._is_specialized = False
    self._set_is_specialized()
    self.load_all()

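Constructing a concrete subclass (such as the hypothetical PerovskiteDatabase sketched in the class overview) creates the stage directories and empty DataFrame placeholders. A rough illustration, assuming data_dir="data":

Python
from pathlib import Path

# PerovskiteDatabase is the hypothetical subclass sketched in the class overview above.
db = PerovskiteDatabase(name="perovskites", data_dir=Path("data"))

# Directory skeleton created by __init__, with the database path used by
# save_database()/load_database() for each stage:
#   data/raw/perovskites/database.json
#   data/processed/perovskites/database.json
#   data/final/perovskites/database.json
# Interim splits are written later under data/interim/perovskites/<model_type>/.

assert db.processing_stages == ["raw", "processed", "final"]
assert db.databases["processed"].empty  # empty unless an existing database.json was loaded by load_all()
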
__repr__()

Text representation of the BaseDatabase instance. Used for print() and str() calls.

Returns:

str: ASCII table representation of the database.

Source code in energy_gnome/dataset/base_dataset.py
Python
def __repr__(self) -> str:
    """
    Text representation of the BaseDatabase instance.
    Used for `print()` and `str()` calls.

    Returns:
        str: ASCII table representation of the database
    """
    # Gather information about each stage
    data = {
        "Stage": [],
        "Entries": [],
        "Last Modified": [],
        "Size": [],
        "Storage Path": [],
    }

    # Calculate column widths
    widths = [10, 8, 17, 10, 55]

    for stage in self.processing_stages:
        # Get database info
        db = self.databases[stage]
        path = self.database_paths[stage]

        # Get file modification time and size if file exists
        if path.exists():
            modified = path.stat().st_mtime
            modified_time = pd.Timestamp.fromtimestamp(modified).strftime("%Y-%m-%d %H:%M")
            size = path.stat().st_size / 1024  # Convert to KB
            size_str = f"{size:.1f} KB" if size < 1024 else f"{size / 1024:.1f} MB"
        else:
            modified_time = "Not created"
            size_str = "0 KB"

        path_str = str(path.resolve())
        if len(path_str) > widths[4]:
            path_str = ".." + path_str[len(path_str) - widths[4] + 3 :]

        # Append data
        data["Stage"].append(stage.capitalize())
        data["Entries"].append(len(db))
        data["Last Modified"].append(modified_time)
        data["Size"].append(size_str)
        data["Storage Path"].append(path_str)

    # Create DataFrame
    info_df = pd.DataFrame(data)

    # Text representation for terminal/print
    def create_separator(widths):
        return "+" + "+".join("-" * (w + 1) for w in widths) + "+"

    # Create the text representation
    lines = []

    # Add title
    title = f" {self.__class__.__name__} Summary "
    lines.append(f"\n{title:=^{sum(widths) + len(widths) * 2 + 1}}")

    # Add header
    separator = create_separator(widths)
    lines.append(separator)

    header = (
        "|" + "|".join(f" {col:<{widths[i]}}" for i, col in enumerate(info_df.columns)) + "|"
    )
    lines.append(header)
    lines.append(separator)

    # Add data rows
    for _, row in info_df.iterrows():
        line = "|" + "|".join(f" {str(val):<{widths[i]}}" for i, val in enumerate(row)) + "|"
        lines.append(line)

    # Add bottom separator
    lines.append(separator)

    return "\n".join(lines)

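The summary is what print(db) shows; the exact contents depend on the databases on disk, so the sketch below only describes the shape of the output.

Python
# db is the instance from the __init__ sketch above.
print(db)
# Prints an ASCII table with one row per processing stage (Raw, Processed, Final)
# and the columns Stage, Entries, Last Modified, Size, and Storage Path.
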
allow_raw_update()

Enables modifications to the raw data stage.

This method sets the internal flag _update_raw to True, allowing changes to be made to the raw data without raising an exception.

Warning

Use with caution, as modifying raw data can impact data integrity.

Source code in energy_gnome/dataset/base_dataset.py
Python
def allow_raw_update(self):
    """
    Enables modifications to the `raw` data stage.

    This method sets the internal flag `_update_raw` to `True`, allowing changes
    to be made to the raw data without raising an exception.

    Warning:
        Use with caution, as modifying raw data can impact data integrity.
    """
    self._update_raw = True

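The flag matters because compare_and_update (documented further below) raises ImmutableRawDataError when new raw entries are found while the flag is still False. A hedged sketch of the intended call order, reusing the db instance from the earlier sketches (the DataFrame contents are made up):

Python
import pandas as pd

# Hypothetical newly retrieved raw entries; real data would come from retrieve_materials().
new_raw_entries = pd.DataFrame(
    {
        "material_id": ["mp-0001"],
        "formula_pretty": ["ABC3"],
        "last_updated": ["2024-01-01"],
    }
)

db.allow_raw_update()  # opt in to modifying the otherwise immutable raw stage
db.compare_and_update(new_raw_entries, stage="raw")
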
backup_and_changelog(old_db, new_db, differences, stage)

Backup the old database and update the changelog with identified differences.

This method saves a backup of the existing database before updating it with new data. It also logs the changes detected by comparing the old and new databases, storing the details in a changelog file.

Parameters:

old_db (DataFrame, required): The existing database before updating.
new_db (DataFrame, required): The new database with updated entries.
differences (Series, required): A series containing the material IDs of entries that differ.
stage (str, required): The processing stage ("raw", "processed", "final") for which the backup and changelog are being maintained.

Raises:

ValueError: If an invalid stage is provided.
OSError: If there is an issue writing to the backup or changelog files.

Logs
  • ERROR: If an invalid stage is provided.
  • DEBUG: When the old database is successfully backed up.
  • ERROR: If the backup process fails.
  • DEBUG: When the changelog is successfully updated with differences.
  • ERROR: If updating the changelog fails.
Source code in energy_gnome/dataset/base_dataset.py
Python
def backup_and_changelog(
    self,
    old_db: pd.DataFrame,
    new_db: pd.DataFrame,
    differences: pd.Series,
    stage: str,
) -> None:
    """
    Backup the old database and update the changelog with identified differences.

    This method saves a backup of the existing database before updating it with new data.
    It also logs the changes detected by comparing the old and new databases, storing the
    details in a changelog file.

    Args:
        old_db (pd.DataFrame): The existing database before updating.
        new_db (pd.DataFrame): The new database with updated entries.
        differences (pd.Series): A series containing the material IDs of entries that differ.
        stage (str): The processing stage ("raw", "processed", "final") for which the backup
                    and changelog are being maintained.

    Raises:
        ValueError: If an invalid `stage` is provided.
        OSError: If there is an issue writing to the backup or changelog files.

    Logs:
        - ERROR: If an invalid stage is provided.
        - DEBUG: When the old database is successfully backed up.
        - ERROR: If the backup process fails.
        - DEBUG: When the changelog is successfully updated with differences.
        - ERROR: If updating the changelog fails.
    """
    if stage not in self.processing_stages:
        logger.error(f"Invalid stage: {stage}. Must be one of {self.processing_stages}.")
        raise ValueError(f"stage must be one of {self.processing_stages}.")

    # Backup the old database
    backup_path = self.database_directories[stage] / "old_database.json"
    try:
        old_db.to_json(backup_path)
        logger.debug(f"Old database backed up to {backup_path}")
    except Exception as e:
        logger.error(f"Failed to backup old database to {backup_path}: {e}")
        raise OSError(f"Failed to backup old database to {backup_path}: {e}") from e

    # Prepare changelog
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    changelog_path = self.database_directories[stage] / "changelog.txt"

    header = (
        f"= Change Log - {timestamp} ".ljust(70, "=") + "\n"
        "Difference old_database.json VS database.json\n"
        f"{'ID':<15}{'Formula':<30}{'Last Updated':<25}\n" + "-" * 70 + "\n"
    )

    # Set index for faster lookup
    new_db_indexed = new_db.set_index("material_id")

    # Process differences efficiently
    changes = [
        f"{identifier:<15}{new_db_indexed.at[identifier, 'formula_pretty'] if identifier in new_db_indexed.index else 'N/A':<30}"
        f"{new_db_indexed.at[identifier, 'last_updated'] if identifier in new_db_indexed.index else 'N/A':<25}\n"
        for identifier in differences["material_id"]
    ]

    try:
        with open(changelog_path, "a") as file:
            file.write(header + "".join(changes))
        logger.debug(f"Changelog updated at {changelog_path} with {len(differences)} changes.")
    except Exception as e:
        logger.error(f"Failed to update changelog at {changelog_path}: {e}")
        raise OSError(f"Failed to update changelog at {changelog_path}: {e}") from e

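Each call appends a block like the following to changelog.txt in the stage directory. The excerpt is schematic: the timestamp, IDs, and formulas are made up, and the column widths follow the fixed-width format used in the code above.

= Change Log - 2024-01-01 12:00:00 ==================================
Difference old_database.json VS database.json
ID             Formula                       Last Updated
----------------------------------------------------------------------
mp-0001        ABC3                          2024-01-01
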
build_reduced_database(size, new_name, stage, seed=42)

Build a reduced database by sampling random entries from an existing database.

This method creates a new database by randomly sampling a specified number of entries from the given stage (raw, processed, or final) of the existing database. The new database is saved and returned as a new instance.

Parameters:

Name Type Description Default
size int

The number of entries for the reduced database.

required
new_name str

The name for the new reduced database.

required
stage str

The processing stage (raw, processed, or final) to sample from.

required
seed int

The random seed used to generate reproducible samples. Defaults to 42.

42

Returns:

Type Description
BaseDatabase

A new database instance of the same class, containing only the sampled entries for the requested stage.

Raises:

Type Description
ValueError

If size is 0, indicating that an empty database is being created.

Logs
  • ERROR: If the database size is set to 0.
  • INFO: If the new reduced database is successfully created and saved.

Source code in energy_gnome/dataset/base_dataset.py
Python
def build_reduced_database(
    self,
    size: int,
    new_name: str,
    stage: str,
    seed: int = 42,
) -> "BaseDatabase":
    """
    Build a reduced database by sampling random entries from an existing database.

    This method creates a new database by randomly sampling a specified number of entries
    from the given stage (`raw`, `processed`, or `final`) of the existing database. The
    new database is saved and returned as a new instance.

    Args:
        size (int): The number of entries for the reduced database.
        new_name (str): The name for the new reduced database.
        stage (str): The processing stage (`raw`, `processed`, or `final`) to sample from.
        seed (int, optional): The random seed used to generate reproducible samples. Defaults to 42.

    Returns:
        (BaseDatabase): A new database instance of the same class, containing the sampled entries for the requested stage.

    Raises:
        ValueError: If `size` is 0, indicating that an empty database is being created.

    Logs:
        ERROR: If the database size is set to 0.
        INFO: If the new reduced database is successfully created and saved.
    """
    new_database = self.__class__(data_dir=self.data_dir, name=new_name)
    if size == 0:
        logger.error("You are creating an empty database.")
        raise ValueError("You are creating an empty database.")
    # copy database
    for stages, db in self.databases.items():
        if len(self.databases[stages]) != 0:
            new_database.databases[stages] = db.copy()
            make_link(self.database_paths[stages], new_database.database_paths[stages])
    for subset, db in self.subset.items():
        if len(self.subset[subset]) != 0:
            new_database.subset[subset] = db.copy()
            subset_path_father = self.data_dir / "interim" / self.name / (subset + "_db.json")
            subset_path_son = (
                new_database.data_dir / "interim" / new_database.name / (subset + "_db.json")
            )
            if not subset_path_son.parent.exists():
                subset_path_son.parent.mkdir(parents=True, exist_ok=True)
            make_link(subset_path_father, subset_path_son)

    database = new_database.databases[stage].copy()
    n_all = len(database)
    rng = Generator(PCG64(seed))
    row_i_all = rng.choice(n_all, size, replace=False)
    new_database.databases[stage] = database.iloc[row_i_all, :].reset_index(drop=True)

    new_database.save_database(stage)

    return new_database
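
A minimal usage sketch with the CathodeDatabase subclass documented below (the names and sizes are illustrative, and a populated raw database is assumed to already exist on disk):

Python
from energy_gnome.dataset import CathodeDatabase

# Load the full raw cathode database from disk.
cathodes = CathodeDatabase(name="cathodes", working_ion="Li")
cathodes.load_database(stage="raw")

# Sample 100 random raw entries into a new, smaller database instance.
small = cathodes.build_reduced_database(size=100, new_name="cathodes_small", stage="raw", seed=0)
print(len(small.get_database(stage="raw")))  # -> 100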

compare_and_update(new_db, stage)

Compare and update the database with new entries.

This method checks for new entries in the provided database and updates the stored database accordingly. It ensures that raw data remains immutable unless explicitly allowed. If new entries are found, the old database is backed up, and a changelog is created.

Parameters:

Name Type Description Default
new_db DataFrame

The new database to compare against the existing one.

required
stage str

The processing stage ("raw", "processed", "final").

required

Returns:

Type Description
None

None. The stored database for the given stage is updated in place and saved when new entries are found.

Raises:

Type Description
ImmutableRawDataError

If attempting to modify raw data without explicit permission.

Logs
  • WARNING: If new items are detected in the database.
  • ERROR: If an attempt is made to modify immutable raw data.
  • INFO: When saving or updating the database.
  • INFO: If no new items are found and no update is required.
Source code in energy_gnome/dataset/base_dataset.py
Python
def compare_and_update(self, new_db: pd.DataFrame, stage: str) -> None:
    """
    Compare and update the database with new entries.

    This method checks for new entries in the provided database and updates the stored
    database accordingly. It ensures that raw data remains immutable unless explicitly
    allowed. If new entries are found, the old database is backed up, and a changelog
    is created.

    Args:
        new_db (pd.DataFrame): The new database to compare against the existing one.
        stage (str): The processing stage ("raw", "processed", "final").

    Returns:
        None: The stored database for `stage` is updated in place and saved when new entries are found.

    Raises:
        ImmutableRawDataError: If attempting to modify raw data without explicit permission.

    Logs:
        - WARNING: If new items are detected in the database.
        - ERROR: If an attempt is made to modify immutable raw data.
        - INFO: When saving or updating the database.
        - INFO: If no new items are found and no update is required.
    """
    old_db = self.get_database(stage=stage)
    db_diff = self.compare_databases(new_db, stage)
    if not db_diff.empty:
        logger.warning(f"The new database contains {len(db_diff)} new items.")

        if stage == "raw" and not self._update_raw:
            logger.error("Raw data must be treated as immutable!")
            logger.error(
                "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
            )
            raise ImmutableRawDataError(
                "Raw data must be treated as immutable!\n"
                "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
            )
        else:
            if stage == "raw":
                logger.info(
                    "Be careful you are changing the raw data which must be treated as immutable!"
                )
            if old_db.empty:
                logger.info(f"Saving new {stage} data in {self.database_paths[stage]}.")
            else:
                logger.info(
                    f"Updating the {stage} data and saving it in {self.database_paths[stage]}."
                )
                self.backup_and_changelog(
                    old_db,
                    new_db,
                    db_diff,
                    stage,
                )
            self.databases[stage] = new_db
            self.save_database(stage)
    else:
        logger.info("No new items found. No update required.")

compare_databases(new_db, stage)

Compare two databases and identify new entry IDs.

This method compares an existing database (loaded from the specified stage) with a new database. It returns the entries from the new database that are not present in the existing one.

Parameters:

Name Type Description Default
new_db DataFrame

New database to compare.

required
stage str

Processing stage ("raw", "processed", "final").

required

Returns:

Type Description
DataFrame

pd.DataFrame: Subset of new_db containing only new entry IDs.

Logs
  • DEBUG: The number of new entries found.
  • WARNING: If the old database is empty and nothing can be compared.
Source code in energy_gnome/dataset/base_dataset.py
Python
def compare_databases(self, new_db: pd.DataFrame, stage: str) -> pd.DataFrame:
    """
    Compare two databases and identify new entry IDs.

    This method compares an existing database (loaded from the specified stage) with a new database.
    It returns the entries from the new database that are not present in the existing one.

    Args:
        new_db (pd.DataFrame): New database to compare.
        stage (str): Processing stage ("raw", "processed", "final").

    Returns:
        pd.DataFrame: Subset of `new_db` containing only new entry IDs.

    Logs:
        - DEBUG: The number of new entries found.
        - WARNING: If the old database is empty and nothing can be compared.
    """
    old_db = self.get_database(stage=stage)
    if not old_db.empty:
        new_ids_set = set(new_db["material_id"])
        old_ids_set = set(old_db["material_id"])
        new_ids_only = new_ids_set - old_ids_set
        logger.debug(f"Found {len(new_ids_only)} new IDs in the new database.")
        return new_db[new_db["material_id"].isin(new_ids_only)]
    else:
        logger.warning("Nothing to compare here...")
        return new_db
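
A minimal sketch. Note that the CathodeDatabase subclass documented below overrides this method to key on battery_id rather than material_id, so that is the column used here; the ID values are purely illustrative:

Python
import pandas as pd

from energy_gnome.dataset import CathodeDatabase

db = CathodeDatabase(working_ion="Li")
db.load_database(stage="raw")

# Only IDs absent from the stored raw database are returned.
new_snapshot = pd.DataFrame({"battery_id": ["batt-001", "batt-002"]})
only_new = db.compare_databases(new_snapshot, stage="raw")
print(only_new["battery_id"].tolist())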

copy_cif_files(stage, mute_progress_bars=True)

Copy CIF files from the raw stage to another processing stage.

This method transfers CIF files from the raw stage directory to the specified processing stage (processed or final). It ensures that existing CIF files in the target directory are cleared before copying and updates the database with the new file paths.

Parameters:

Name Type Description Default
stage str

The processing stage to copy CIF files to (processed, final).

required
mute_progress_bars bool

If True, disables progress bars. Defaults to True.

True

Raises:

Type Description
ValueError

If the stage argument is raw, as copying is only allowed from raw to other stages.

MissingData

If the raw CIF directory does not exist or is empty.

Logs
  • WARNING: If the target directory is cleaned or CIF files are missing for some materials.
  • ERROR: If a CIF file fails to copy.
  • INFO: When CIF files are successfully copied and the database is updated.
Source code in energy_gnome/dataset/base_dataset.py
Python
def copy_cif_files(
    self,
    stage: str,
    mute_progress_bars: bool = True,
) -> None:
    """
    Copy CIF files from the raw stage to another processing stage.

    This method transfers CIF files from the `raw` stage directory to the specified
    processing stage (`processed` or `final`). It ensures that existing CIF files in
    the target directory are cleared before copying and updates the database with
    the new file paths.

    Args:
        stage (str): The processing stage to copy CIF files to (`processed`, `final`).
        mute_progress_bars (bool, optional): If True, disables progress bars. Defaults to True.

    Raises:
        ValueError: If the stage argument is `raw`, as copying is only allowed from `raw` to other stages.
        MissingData: If the raw CIF directory does not exist or is empty.

    Logs:
        - WARNING: If the target directory is cleaned or CIF files are missing for some materials.
        - ERROR: If a CIF file fails to copy.
        - INFO: When CIF files are successfully copied and the database is updated.
    """
    if stage == "raw":
        logger.error(
            "Stage argument cannot be 'raw'. You can only copy from 'raw' to other stages."
        )
        raise ValueError("Stage argument cannot be 'raw'.")

    source_dir = self.database_directories["raw"] / "structures"
    saving_dir = self.database_directories[stage] / "structures"

    # Clean the target directory if it exists
    if saving_dir.exists():
        logger.warning(f"Cleaning the content in {saving_dir}")
        sh.rmtree(saving_dir)

    # Check if source directory exists and is not empty
    cif_files = {
        file.stem for file in source_dir.glob("*.cif")
    }  # Set of existing CIF filenames
    if not cif_files:
        logger.warning(
            f"The raw CIF directory does not exist or is empty. Check: {source_dir}"
        )
        raise MissingData(
            f"The raw CIF directory does not exist or is empty. Check: {source_dir}"
        )

    # Create the target directory
    saving_dir.mkdir(parents=True, exist_ok=False)

    # Create an index mapping for fast row updates
    db_stage = self.databases[stage].set_index("material_id")
    db_stage["cif_path"] = pd.NA  # Initialize empty column

    missing_ids = []
    for material_id in tqdm(
        self.databases[stage]["material_id"],
        desc=f"Copying materials ('raw' -> '{stage}')",
        disable=mute_progress_bars,
    ):
        if material_id not in cif_files:
            missing_ids.append(material_id)
            continue  # Skip missing files

        source_cif_path = source_dir / f"{material_id}.cif"
        cif_path = saving_dir / f"{material_id}.cif"

        try:
            sh.copy2(source_cif_path, cif_path)
            db_stage.at[material_id, "cif_path"] = str(cif_path)  # Direct assignment

        except Exception as e:
            logger.error(f"Failed to copy CIF for Material ID {material_id}: {e}")
            continue  # Skip to next material instead of stopping execution

    # Restore the updated database index
    self.databases[stage] = db_stage.reset_index()

    # Log missing files once
    if missing_ids:
        logger.warning(f"Missing CIF files for {len(missing_ids)} material IDs.")

    # Save the updated database
    self.save_database(stage)
    logger.info(f"CIF files copied to stage '{stage}' and database updated successfully.")

get_database(stage, subset=None)

Retrieves the database for a specified processing stage or subset.

This method returns the database associated with the given stage. If the stage is raw, processed, or final, it retrieves the corresponding database. If stage is interim, a specific subset (e.g., training, validation, testing) must be provided.

Parameters:

Name Type Description Default
stage str

The processing stage to retrieve. Must be one of raw, processed, final, or interim.

required
subset Optional[str]

The subset to retrieve when stage is interim. Must be one of training, validation, or testing.

None

Returns:

Type Description
DataFrame

pd.DataFrame: The requested database or subset.

Raises:

Type Description
ValueError

If an invalid stage is provided.

ValueError

If stage is interim but an invalid subset is specified.

Logs
  • ERROR: If an invalid stage is provided.
  • WARNING: If the retrieved database is empty.
Source code in energy_gnome/dataset/base_dataset.py
Python
def get_database(self, stage: str, subset: str | None = None) -> pd.DataFrame:
    """
    Retrieves the database for a specified processing stage or subset.

    This method returns the database associated with the given `stage`. If the stage
    is `raw`, `processed`, or `final`, it retrieves the corresponding database.
    If `stage` is `interim`, a specific `subset` (e.g., `training`, `validation`, `testing`)
    must be provided.

    Args:
        stage (str): The processing stage to retrieve. Must be one of
            `raw`, `processed`, `final`, or `interim`.
        subset (Optional[str]): The subset to retrieve when `stage` is `interim`.
            Must be one of `training`, `validation`, or `testing`.

    Returns:
        pd.DataFrame: The requested database or subset.

    Raises:
        ValueError: If an invalid `stage` is provided.
        ValueError: If `stage` is `interim` but an invalid `subset` is specified.

    Logs:
        - ERROR: If an invalid `stage` is provided.
        - WARNING: If the retrieved database is empty.
    """
    if stage not in self.processing_stages + ["interim"]:
        logger.error(
            f"Invalid stage: {stage}. Must be one of {self.processing_stages + ['interim']}."
        )
        raise ValueError(f"stage must be one of {self.processing_stages + ['interim']}.")
    if stage in self.processing_stages:
        out_db = self.databases[stage]
    elif stage in ["interim"] and subset is not None:
        out_db = self.subset[subset]
    else:
        raise ValueError(f"subset must be one of {self.interim_sets}.")

    if len(out_db) == 0:
        logger.warning("Empty database found.")
    return out_db
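
A short usage sketch with the CathodeDatabase subclass documented below:

Python
from energy_gnome.dataset import CathodeDatabase

db = CathodeDatabase(working_ion="Li")
db.load_all()

processed = db.get_database(stage="processed")                   # full processed table
training = db.get_database(stage="interim", subset="training")   # interim training split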

load_all()

Loads the databases for all processing stages and subsets.

This method sequentially loads the databases for all predefined processing stages (raw, processed, final). Additionally, it loads both regressor and classifier data for all interim subsets (training, validation, testing).

Calls
  • load_database(stage): Loads the database for each processing stage.
  • load_regressor_data(subset): Loads regressor-specific data for each subset.
  • load_classifier_data(subset): Loads classifier-specific data for each subset.
Source code in energy_gnome/dataset/base_dataset.py
Python
def load_all(self):
    """
    Loads the databases for all processing stages and subsets.

    This method sequentially loads the databases for all predefined processing
    stages (`raw`, `processed`, `final`). Additionally, it loads both regressor
    and classifier data for all interim subsets (`training`, `validation`, `testing`).

    Calls:
        - `load_database(stage)`: Loads the database for each processing stage.
        - `load_regressor_data(subset)`: Loads regressor-specific data for each subset.
        - `load_classifier_data(subset)`: Loads classifier-specific data for each subset.
    """
    for stage in self.processing_stages:
        self.load_database(stage)
    for subset in self.interim_sets:
        self.load_regressor_data(subset)
        self.load_classifier_data(subset)
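
A minimal sketch; files that are missing on disk are only reported as warnings:

Python
from energy_gnome.dataset import CathodeDatabase

db = CathodeDatabase(working_ion="Li")

# Populate raw/processed/final plus every interim regressor and classifier
# split that has been saved previously.
db.load_all()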

load_classifier_data(subset='training')

Load the interim dataset for a classification model.

This method retrieves the specified subset of the interim dataset specifically for classification models by internally calling _load_interim.

Parameters:

Name Type Description Default
subset str

The subset of the dataset to load (training, testing). Defaults to "training".

'training'

Returns:

Type Description
DataFrame

The loaded classification dataset or an empty DataFrame if not found.

Raises:

Type Description
ValueError

If the provided subset is not one of the allowed interim sets.

Logs
  • ERROR: If an invalid subset is provided.
  • DEBUG: If an existing database is successfully loaded.
  • WARNING: If no database file is found.
Source code in energy_gnome/dataset/base_dataset.py
Python
def load_classifier_data(self, subset: str = "training"):
    """
    Load the interim dataset for a classification model.

    This method retrieves the specified subset of the interim dataset specifically for
    classification models by internally calling `_load_interim`.

    Args:
        subset (str, optional): The subset of the dataset to load (`training`, `testing`).
            Defaults to "training".

    Returns:
        (pd.DataFrame): The loaded classification dataset or an empty DataFrame if not found.

    Raises:
        ValueError: If the provided `subset` is not one of the allowed interim sets.

    Logs:
        - ERROR: If an invalid subset is provided.
        - DEBUG: If an existing database is successfully loaded.
        - WARNING: If no database file is found.
    """
    return self._load_interim(subset=subset, model_type="classifier")

load_database(stage)

Loads the existing database for a specified processing stage.

This method retrieves the database stored in a JSON file for the given processing stage (raw, processed, or final). If the file exists, it loads the data into a pandas DataFrame. If the file is missing, a warning is logged, and an empty DataFrame remains in place.

Parameters:

Name Type Description Default
stage str

The processing stage to load. Must be one of raw, processed, or final.

required

Raises:

Type Description
ValueError

If stage is not one of the predefined processing stages.

Logs
  • ERROR: If an invalid stage is provided.
  • DEBUG: If a database is successfully loaded.
  • WARNING: If the database file is not found.

Source code in energy_gnome/dataset/base_dataset.py
Python
def load_database(self, stage: str) -> None:
    """
    Loads the existing database for a specified processing stage.

    This method retrieves the database stored in a JSON file for the given
    processing stage (`raw`, `processed`, or `final`). If the file exists,
    it loads the data into a pandas DataFrame. If the file is missing,
    a warning is logged, and an empty DataFrame remains in place.

    Args:
        stage (str): The processing stage to load. Must be one of
            `raw`, `processed`, or `final`.

    Raises:
        ValueError: If `stage` is not one of the predefined processing stages.

    Logs:
        ERROR: If an invalid `stage` is provided.
        DEBUG: If a database is successfully loaded.
        WARNING: If the database file is not found.
    """
    if stage not in self.processing_stages:
        logger.error(f"Invalid stage: {stage}. Must be one of {self.processing_stages}.")
        raise ValueError(f"stage must be one of {self.processing_stages}.")

    db_path = self.database_paths[stage]
    if db_path.exists():
        self.databases[stage] = pd.read_json(db_path)
        if stage == "raw":
            self.databases[stage]["is_specialized"] = self._is_specialized
        logger.debug(f"Loaded existing database from {db_path}.")
    else:
        logger.warning(f"Not found at {db_path}.")

load_regressor_data(subset='training')

Load the interim dataset for a regression model.

This method retrieves the specified subset of the interim dataset specifically for regression models by internally calling _load_interim.

Parameters:

Name Type Description Default
subset str

The subset of the dataset to load (training, validation, testing). Defaults to "training".

'training'

Returns:

Type Description
DataFrame

The loaded regression dataset or an empty DataFrame if not found.

Raises:

Type Description
ValueError

If the provided subset is not one of the allowed interim sets.

Logs
  • ERROR: If an invalid subset is provided.
  • DEBUG: If an existing database is successfully loaded.
  • WARNING: If no database file is found.
Source code in energy_gnome/dataset/base_dataset.py
Python
def load_regressor_data(self, subset: str = "training"):
    """
    Load the interim dataset for a regression model.

    This method retrieves the specified subset of the interim dataset specifically for
    regression models by internally calling `_load_interim`.

    Args:
        subset (str, optional): The subset of the dataset to load (`training`, `validation`, `testing`).
            Defaults to "training".

    Returns:
        (pd.DataFrame): The loaded regression dataset or an empty DataFrame if not found.

    Raises:
        ValueError: If the provided `subset` is not one of the allowed interim sets.

    Logs:
        - ERROR: If an invalid subset is provided.
        - DEBUG: If an existing database is successfully loaded.
        - WARNING: If no database file is found.
    """
    return self._load_interim(subset=subset, model_type="regressor")

save_database(stage)

Saves the current state of the database to a JSON file.

This method serializes the database DataFrame for the specified processing stage (raw, processed, or final) and writes it to a JSON file. If an existing file is present, it is removed before saving the new version.

Parameters:

Name Type Description Default
stage str

The processing stage to save. Must be one of raw, processed, or final.

required

Raises:

Type Description
ValueError

If stage is not one of the predefined processing stages.

OSError

If an error occurs while writing the DataFrame to the file.

Logs
  • ERROR: If an invalid stage is provided.
  • INFO: If the database is successfully saved.
  • ERROR: If the save operation fails.
Source code in energy_gnome/dataset/base_dataset.py
Python
def save_database(self, stage: str) -> None:
    """
    Saves the current state of the database to a JSON file.

    This method serializes the database DataFrame for the specified processing
    stage (`raw`, `processed`, or `final`) and writes it to a JSON file. If an
    existing file is present, it is removed before saving the new version.

    Args:
        stage (str): The processing stage to save. Must be one of
            `raw`, `processed`, or `final`.

    Raises:
        ValueError: If `stage` is not one of the predefined processing stages.
        OSError: If an error occurs while writing the DataFrame to the file.

    Logs:
        - ERROR: If an invalid `stage` is provided.
        - INFO: If the database is successfully saved.
        - ERROR: If the save operation fails.
    """
    if stage not in self.processing_stages:
        logger.error(f"Invalid stage: {stage}. Must be one of {self.processing_stages}.")
        raise ValueError(f"stage must be one of {self.processing_stages}.")

    db_path = self.database_paths[stage]
    if os.path.exists(db_path):
        os.unlink(db_path)
    try:
        self.databases[stage].to_json(db_path)
        logger.info(f"Database saved to {db_path}")
    except Exception as e:
        logger.error(f"Failed to save database to {db_path}: {e}")
        raise OSError(f"Failed to save database to {db_path}: {e}") from e

save_split_db(database_dict, model_type='regressor')

Saves the split databases (training, validation, testing) into JSON files.

This method saves the split databases into individual files in the designated directory for the given model type (e.g., regressor, classifier). It checks whether the databases are empty before saving. If any of the databases are empty, it logs a warning or raises an error depending on the subset (training, validation, testing).

Parameters:

Name Type Description Default
database_dict dict

A dictionary containing the split databases (train, valid, test) as pandas DataFrames.

required
model_type str

The model type for which the splits are being saved (e.g., "regressor", "classifier"). Defaults to "regressor".

'regressor'

Returns:

Type Description
None

None

Raises:

Type Description
DatabaseError

If the training dataset is empty when attempting to save.

Logs
  • INFO: When a database is successfully saved to its designated path.
  • WARNING: When the validation or testing database is empty.
  • ERROR: If any dataset is empty and it is the train subset.

Source code in energy_gnome/dataset/base_dataset.py
Python
def save_split_db(self, database_dict: dict, model_type: str = "regressor") -> None:
    """
    Saves the split databases (training, validation, testing) into JSON files.

    This method saves the split databases into individual files in the designated
    directory for the given model type (e.g., `regressor`, `classifier`). It checks
    whether the databases are empty before saving. If any of the databases are empty,
    it logs a warning or raises an error depending on the subset (training, validation, testing).

    Args:
        database_dict (dict): A dictionary containing the split databases (`train`,
            `valid`, `test`) as pandas DataFrames.
        model_type (str, optional): The model type for which the splits are being saved
            (e.g., `"regressor"`, `"classifier"`). Defaults to `"regressor"`.

    Returns:
        None

    Raises:
        DatabaseError: If the training dataset is empty
            when attempting to save.

    Logs:
        INFO: When a database is successfully saved to its designated path.
        WARNING: When the validation or testing database is empty.
        ERROR: If any dataset is empty and it is the `train` subset.
    """
    db_path = self.data_dir / "interim" / self.name / model_type
    # if not db_path.exists():
    db_path.mkdir(parents=True, exist_ok=True)

    split_paths = dict(
        train=db_path / "training_db.json",
        valid=db_path / "validation_db.json",
        test=db_path / "testing_db.json",
    )

    for db_name, db_path in split_paths.items():
        if database_dict[db_name].empty:
            if db_name == "valid":
                logger.warning("No validation database provided.")
                continue
            elif db_name == "test":
                logger.warning("No testing database provided.")
                continue
            else:
                raise DatabaseError(f"The {db_name} is empty, check the splitting.")
        else:
            database_dict[db_name].to_json(db_path)
            logger.info(f"{db_name} database saved to {db_path}")

split_classifier(test_size=0.2, seed=42, balance_composition=False, save_split=False)

Splits the processed database into training and test sets for classification tasks.

This method divides the database into two subsets: training and test. When balance_composition is True, the split uses the target property (is_specialized) and additionally balances the frequency of chemical species across the training and test sets; otherwise a plain random split is performed. The size of the test set can be customized with the test_size argument.

Parameters:

Name Type Description Default
test_size float

The proportion of the data to use for the test set. Defaults to 0.2.

0.2
seed int

The random seed for reproducibility. Defaults to 42.

42
balance_composition bool

Whether to balance the frequency of chemical species across the training and test sets in addition to stratifying by the target property (is_specialized). Defaults to False.

False
save_split bool

Whether to save the resulting splits as files. Defaults to False.

False

Returns:

Type Description
None

None

Raises:

Type Description
ValueError

If the database is empty or invalid.

Logs
  • INFO: If the dataset is successfully split.
  • ERROR: If the dataset is invalid or empty.

Source code in energy_gnome/dataset/base_dataset.py
Python
def split_classifier(
    self,
    test_size: float = 0.2,
    seed: int = 42,
    balance_composition: bool = False,
    save_split: bool = False,
) -> None:
    """
    Splits the processed database into training and test sets for classification tasks.

    This method divides the database into two subsets: training and test. When
    `balance_composition` is True, the split uses the target property (`is_specialized`)
    and additionally balances the frequency of chemical species across the training and
    test sets; otherwise a plain random split is performed. The size of the test set can
    be customized with the `test_size` argument.

    Args:
        test_size (float, optional): The proportion of the data to use for the test set.
            Defaults to 0.2.
        seed (int, optional): The random seed for reproducibility. Defaults to 42.
        balance_composition (bool, optional): Whether to balance the frequency of chemical
            species across the training and test sets in addition to stratifying by the
            target property (`is_specialized`). Defaults to False.
        save_split (bool, optional): Whether to save the resulting splits as files. Defaults to False.

    Returns:
        None

    Raises:
        ValueError: If the database is empty or invalid.

    Logs:
        INFO: If the dataset is successfully split.
        ERROR: If the dataset is invalid or empty.
    """
    target_property = "is_specialized"
    if balance_composition:
        db_dict = random_split(
            self.get_database("processed"),
            target_property,
            valid_size=0,
            test_size=test_size,
            seed=seed,
        )
    else:
        train_, test_ = train_test_split(
            self.get_database("processed"), test_size=test_size, random_state=seed
        )
        db_dict = {"train": train_, "test": test_}

    if save_split:
        self.save_split_db(db_dict, "classifier")

    self.subset = db_dict
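
A minimal sketch; with save_split=True the splits can be reloaded later via load_classifier_data:

Python
from energy_gnome.dataset import CathodeDatabase

db = CathodeDatabase(working_ion="Li")
db.load_database(stage="processed")

# 80/20 split on `is_specialized`, balancing chemical composition, persisted
# to the interim classifier directory.
db.split_classifier(test_size=0.2, seed=42, balance_composition=True, save_split=True)
train_clf = db.load_classifier_data(subset="training")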

split_regressor(target_property, valid_size=0.2, test_size=0.05, seed=42, balance_composition=True, save_split=False)

Splits the processed database into training, validation, and test sets for regression tasks.

This method divides the database into three subsets: training, validation, and test, using a random split that can optionally balance the frequency of chemical species across the splits. If balance_composition is True, it ensures that elements appear in approximately equal proportions in each subset. The split sizes for the validation and test sets can be customized.

Parameters:

Name Type Description Default
target_property str

The property used for the regression task (e.g., a material property like "energy").

required
valid_size float

The proportion of the data to use for the validation set. Defaults to 0.2.

0.2
test_size float

The proportion of the data to use for the test set. Defaults to 0.05.

0.05
seed int

The random seed for reproducibility. Defaults to 42.

42
balance_composition bool

Whether to balance the frequency of chemical species across the subsets. Defaults to True.

True
save_split bool

Whether to save the resulting splits as files. Defaults to False.

False

Returns:

Type Description
None

None

Raises:

Type Description
ValueError

If the sum of valid_size and test_size exceeds 1.

Logs
  • INFO: If the dataset is successfully split.
  • ERROR: If the sum of valid_size and test_size is greater than 1.

Source code in energy_gnome/dataset/base_dataset.py
Python
def split_regressor(
    self,
    target_property: str,
    valid_size: float = 0.2,
    test_size: float = 0.05,
    seed: int = 42,
    balance_composition: bool = True,
    save_split: bool = False,
) -> None:
    """
    Splits the processed database into training, validation, and test sets for regression tasks.

    This method divides the database into three subsets: training, validation, and test,
    using a random split that can optionally balance the frequency of chemical species
    across the splits. If `balance_composition` is True, it ensures that elements appear
    in approximately equal proportions in each subset. The split sizes for the validation
    and test sets can be customized.

    Args:
        target_property (str): The property used for the regression task (e.g., a material
            property like "energy").
        valid_size (float, optional): The proportion of the data to use for the validation set.
            Defaults to 0.2.
        test_size (float, optional): The proportion of the data to use for the test set.
            Defaults to 0.05.
        seed (int, optional): The random seed for reproducibility. Defaults to 42.
        balance_composition (bool, optional): Whether to balance the frequency of chemical species
            across the subsets. Defaults to True.
        save_split (bool, optional): Whether to save the resulting splits as files. Defaults to False.

    Returns:
        None

    Raises:
        ValueError: If the sum of `valid_size` and `test_size` exceeds 1.

    Logs:
        INFO: If the dataset is successfully split.
        ERROR: If the sum of `valid_size` and `test_size` is greater than 1.
    """
    if balance_composition:
        db_dict = random_split(
            self.get_database("processed"),
            target_property,
            valid_size=valid_size,
            test_size=test_size,
            seed=seed,
        )
    else:
        dev_size = valid_size + test_size
        if abs(dev_size - test_size) < 1e-8:
            train_, test_ = train_test_split(
                self.get_database("processed"), test_size=test_size, random_state=seed
            )
            valid_ = pd.DataFrame()
        elif abs(dev_size - valid_size) < 1e-8:
            train_, valid_ = train_test_split(
                self.get_database("processed"), test_size=valid_size, random_state=seed
            )
            test_ = pd.DataFrame()
        else:
            train_, dev_ = train_test_split(
                self.get_database("processed"),
                test_size=valid_size + test_size,
                random_state=seed,
            )
            valid_, test_ = train_test_split(
                dev_, test_size=test_size / (valid_size + test_size), random_state=seed
            )

        db_dict = {"train": train_, "valid": valid_, "test": test_}

    if save_split:
        self.save_split_db(db_dict, "regressor")

    self.subset = db_dict
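
A minimal sketch; "average_voltage" is an illustrative target column name, to be replaced by any numeric property present in the processed database:

Python
from energy_gnome.dataset import CathodeDatabase

db = CathodeDatabase(working_ion="Li")
db.load_database(stage="processed")

db.split_regressor(
    target_property="average_voltage",   # illustrative column name
    valid_size=0.2,
    test_size=0.05,
    seed=42,
    balance_composition=True,
    save_split=True,
)
train_reg = db.load_regressor_data(subset="training")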

energy_gnome.dataset.CathodeDatabase

Bases: BaseDatabase

Source code in energy_gnome/dataset/cathodes.py
Python
class CathodeDatabase(BaseDatabase):
    def __init__(
        self,
        data_dir: Path = DATA_DIR,
        name: str = "cathodes",
        working_ion: str = "Li",
        battery_type: str = "insertion",
    ):
        """
        Initialize the CathodeDatabase with a root data directory and processing stage.

        Sets up the directory structure for storing data across different processing stages
        (`raw/`, `processed/`, `final/`) and initializes placeholders for database paths and data.

        Args:
            data_dir (Path, optional): Root directory path for storing data.
                                       Defaults to DATA_DIR from config.
            working_ion (str, optional): The working ion used in the dataset (e.g., 'Li').
                                         Defaults to "Li".
            battery_type (str, optional): The type of battery type (e.g., 'insertion', 'conversion').
                                          Defaults to "insertion".

        Raises:
            NotImplementedError: If the specified processing stage is not supported.
            ImmutableRawDataError: If attempting to set an unsupported processing stage.
        """
        super().__init__(name=name, data_dir=data_dir)
        self.working_ion = working_ion

        if battery_type == "insertion":
            self.battery_type = battery_type
        elif battery_type == "conversion":
            logger.error("`conversion` battery type is not yet implemented in Material Project.")
            raise NotImplementedError(
                "`conversion` battery type is not yet present in Material Project."
            )
        else:
            logger.error(
                f"Invalid battery type: {battery_type}. Must be 'insertion' or 'conversion'."
            )
            raise ValueError(
                "`battery_type` can be only `insertion` or `conversion` (not yet implemented)"
            )

        # Initialize directories, paths, and databases for each stage
        self.database_directories = {
            stage: self.data_dir / stage / "cathodes" / battery_type / working_ion
            for stage in self.processing_stages
        }
        for stage_dir in self.database_directories.values():
            stage_dir.mkdir(parents=True, exist_ok=True)

        self.database_paths = {
            stage: dir_path / "database.json"
            for stage, dir_path in self.database_directories.items()
        }

        self.databases = {stage: pd.DataFrame() for stage in self.processing_stages}
        self._battery_models = pd.DataFrame()

    def retrieve_remote(self, mute_progress_bars: bool = True) -> pd.DataFrame:
        """
        Retrieve models from the Materials Project API.

        Wrapper method to call `retrieve_models`.

        Args:
            mute_progress_bars (bool, optional):
                If `True`, mutes the Materials Project API progress bars.
                Defaults to `True`.

        Returns:
            pd.DataFrame: DataFrame containing the retrieved models.
        """
        return self.retrieve_models(mute_progress_bars=mute_progress_bars)

    def retrieve_models(self, mute_progress_bars: bool = True) -> pd.DataFrame:
        """
        Retrieve battery models from the Materials Project API.

        Connects to the Materials Project API using MPRester, queries for materials
        based on the working ion and processing stage, and retrieves the specified fields.
        Cleans the data by removing entries with missing critical identifiers.

        Args:
            mute_progress_bars (bool, optional):
                If `True`, mutes the Materials Project API progress bars.
                Defaults to `True`.

        Returns:
            pd.DataFrame: DataFrame containing the retrieved and cleaned models.

        Raises:
            Exception: If the API query fails.
        """
        mp_api_key = get_mp_api_key()
        logger.debug("MP querying for insertion battery models.")

        with MPRester(mp_api_key, mute_progress_bars=mute_progress_bars) as mpr:
            try:
                query = mpr.materials.insertion_electrodes.search(
                    working_ion=self.working_ion, fields=BAT_FIELDS
                )
                logger.info(
                    f"MP query successful, {len(query)} {self.working_ion}-ion batteries found."
                )
            except Exception as e:
                raise e
        logger.debug("Converting MP query results into DataFrame.")
        battery_models_database = convert_my_query_to_dataframe(
            query, mute_progress_bars=mute_progress_bars
        )

        # Fast cleaning
        logger.debug("Removing NaN")
        battery_models_database = battery_models_database.dropna(
            axis=0, how="any", subset=["id_charge", "id_discharge"]
        )
        battery_models_database = battery_models_database.dropna(axis=1, how="all")
        self._battery_models = battery_models_database
        logger.success(f"{self.working_ion}-ion batteries model retrieved successfully.")
        return self._battery_models

    def compare_databases(self, new_db: pd.DataFrame, stage: str) -> pd.DataFrame:
        """
        Compare two databases and identify new entry IDs.

        Args:
            new_db (pd.DataFrame): New database to compare.
            stage (str): Processing stage ("raw", "processed", "final").

        Returns:
            pd.DataFrame: Subset of `new_db` containing only new entry IDs.
        """
        old_db = self.get_database(stage=stage)
        if not old_db.empty:
            new_ids_set = set(new_db["battery_id"])
            old_ids_set = set(old_db["battery_id"])
            new_ids_only = new_ids_set - old_ids_set
            logger.debug(f"Found {len(new_ids_only)} new battery IDs in the new database.")
            return new_db[new_db["battery_id"].isin(new_ids_only)]
        else:
            logger.warning("Nothing to compare here...")
            return new_db

    def backup_and_changelog(
        self,
        old_db: pd.DataFrame,
        new_db: pd.DataFrame,
        differences: pd.Series,
        stage: str,
    ) -> None:
        """
        Backup the old database and update the changelog with identified differences.

        Creates a backup of the existing database and appends a changelog entry detailing
        the differences between the old and new databases. The changelog includes
        information such as entry identifiers, formulas, and last updated timestamps.

        Args:
            old_db (pd.DataFrame): The existing database before updates.
            new_db (pd.DataFrame): The new database containing updates.
            differences (pd.Series): Series of identifiers that are new or updated.
            stage (str): The processing stage ('raw', 'processed', 'final').
        """
        if stage not in self.processing_stages:
            logger.error(f"Invalid stage: {stage}. Must be one of {self.processing_stages}.")
            raise ValueError(f"stage must be one of {self.processing_stages}.")

        backup_path = self.database_directories[stage] / "old_database.json"
        try:
            old_db.to_json(backup_path)
            logger.debug(f"Old database backed up to {backup_path}")
        except Exception as e:
            logger.error(f"Failed to backup old database to {backup_path}: {e}")
            raise OSError(f"Failed to backup old database to {backup_path}: {e}") from e

        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        changelog_path = self.database_directories[stage] / "changelog.txt"
        changelog_entries = [
            f"= Change Log - {timestamp} ".ljust(70, "=") + "\n",
            "Difference old_database.json VS database.json\n",
            f"{'ID':<15}{'Formula':<30}{'Last Updated (MP)':<25}\n",
            "-" * 70 + "\n",
        ]
        # Tailored relative to the parent class: rows are keyed on battery_id and battery_formula
        for identifier in differences["battery_id"]:
            row = new_db.loc[new_db["battery_id"] == identifier]
            if not row.empty:
                formula = row["battery_formula"].values[0]
                last_updated = row["last_updated"].values[0]
            else:
                formula = "N/A"
                last_updated = "N/A"
            changelog_entries.append(f"{identifier:<15}{formula:<30}{last_updated:<20}\n")

        try:
            with open(changelog_path, "a") as file:
                file.writelines(changelog_entries)
            logger.debug(f"Changelog updated at {changelog_path} with {len(differences)} changes.")
        except Exception as e:
            logger.error(f"Failed to update changelog at {changelog_path}: {e}")
            raise OSError(f"Failed to update changelog at {changelog_path}: {e}") from e

    def compare_and_update(self, new_db: pd.DataFrame, stage: str) -> None:
        """
        Compare and update the database with new entries.

        Identifies new entries and updates the database accordingly. Ensures that raw data
        remains immutable by preventing updates unless explicitly allowed.

        Args:
            new_db (pd.DataFrame): New database to compare.
            stage (str): Processing stage ("raw", "processed", "final").

        Returns:
            None: The stored database for `stage` is updated in place and saved when new entries are found.

        Raises:
            ImmutableRawDataError: If attempting to modify immutable raw data.
        """
        old_db = self.get_database(stage=stage)
        db_diff = self.compare_databases(new_db, stage)
        if not db_diff.empty:
            logger.warning(f"The new database contains {len(db_diff)} new items.")

            if stage == "raw" and not self._update_raw:
                logger.error("Raw data must be treated as immutable!")
                logger.error(
                    "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
                )
                raise ImmutableRawDataError(
                    "Raw data must be treated as immutable!\n"
                    "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
                )
            else:
                if stage == "raw":
                    logger.info(
                        "Be careful you are changing the raw data which must be treated as immutable!"
                    )
                logger.info(
                    f"Updating the {stage} data and saving it in {self.database_paths[stage]}."
                )
                self.backup_and_changelog(
                    old_db,
                    new_db,
                    db_diff,
                    stage,
                )
                self.databases[stage] = new_db
                self.save_database(stage)
        else:
            logger.info("No new items found. No update required.")

    def retrieve_materials(
        self, stage: str, charge_state: str, mute_progress_bars: bool = True
    ) -> list[Any]:
        """
        Retrieve material structures from the Materials Project API.

        Fetches material structures based on the processing stage and charge state.

        Args:
            stage (str): Processing stage ('raw', 'processed', 'final').
            charge_state (str): Cathode charge state ('charge', 'discharge').
            mute_progress_bars (bool, optional): Disable progress bar if True. Defaults to True.

        Returns:
            List[Any]: List of retrieved material objects.

        Raises:
            ValueError: If the charge_state is invalid.
            MissingData: If the required data is missing in the database.
        """
        if charge_state not in ["charge", "discharge"]:
            logger.error(f"Invalid charge_state: {charge_state}. Must be 'charge' or 'discharge'.")
            raise ValueError("charge_state must be 'charge' or 'discharge'.")

        material_ids = self.databases[stage][f"id_{charge_state}"].tolist()
        if not material_ids:
            logger.warning(
                f"No material IDs found for stage '{stage}' and charge_state '{charge_state}'."
            )
            raise MissingData(
                f"No material IDs found for stage '{stage}' and charge_state '{charge_state}'."
            )

        logger.debug(
            f"Retrieving materials for stage '{stage}' and charge_state '{charge_state}'."
        )
        query = get_material_by_id(
            material_ids,
            mute_progress_bars=mute_progress_bars,
        )
        return query

    def _add_materials_properties_columns(self, stage: str, charge_state: str) -> pd.DataFrame:
        """
        Add material properties columns to the database for a given cathode state.

        Args:
            stage (str): Processing stage ('raw', 'processed', 'final').
            charge_state (str): Cathode charge state ('charge', 'discharge').

        Returns:
            pd.DataFrame: Updated database with material properties columns.

        Raises:
            ImmutableRawDataError: If attempting to modify immutable raw data.
        """
        if charge_state not in ["charge", "discharge"]:
            logger.error(f"Invalid charge_state: {charge_state}. Must be 'charge' or 'discharge'.")
            raise ValueError("charge_state must be 'charge' or 'discharge'.")

        if stage == "raw" and not self._update_raw:
            logger.error("Raw data must be treated as immutable!")
            logger.error(
                "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
            )
            raise ImmutableRawDataError(
                "Raw data must be treated as immutable!\n"
                "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
            )
        else:
            if stage == "raw":
                logger.info(
                    "Be careful you are changing the raw data which must be treated as immutable!"
                )
            logger.debug(
                f"Adding material properties to {stage} data for cathode state: {charge_state}"
            )
            for property_name, dtype in MAT_PROPERTIES.items():
                column_name = f"{charge_state}_{property_name}"
                if column_name not in self.databases[stage].columns:
                    logger.debug(f"Adding missing column: {column_name} with dtype {dtype}")
                    self.databases[stage][column_name] = pd.Series(dtype=dtype)

    def add_material_properties(
        self,
        stage: str,
        materials_mp_query: list,
        charge_state: str,
        mute_progress_bars: bool = True,
    ) -> pd.DataFrame:
        """
        Add material properties to the database from Materials Project query results.

        Saves CIF files for each material in the query and updates the database with file paths and properties.

        Args:
            stage (str): Processing stage ('raw', 'processed', 'final').
            materials_mp_query (List[Any]): List of material query results.
            charge_state (str): The state of the cathode ('charge' or 'discharge').
            mute_progress_bars (bool, optional): Disable progress bar if True. Defaults to True.

        Returns:
            pd.DataFrame: Updated database with material properties.

        Raises:
            ImmutableRawDataError: If attempting to modify immutable raw data.
            KeyError: If a material ID is not found in the database.
        """
        if stage == "raw" and not self._update_raw:
            logger.error("Raw data must be treated as immutable!")
            logger.error(
                "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
            )
            raise ImmutableRawDataError(
                "Raw data must be treated as immutable!\n"
                "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
            )
        else:
            if stage == "raw":
                logger.info(
                    "Be careful you are changing the raw data which must be treated as immutable!"
                )
            logger.debug(
                f"Adding material properties to {stage} data for cathode state: {charge_state}"
            )

            # Ensure necessary columns are present
            self._add_materials_properties_columns(stage, charge_state)

            for material in tqdm(
                materials_mp_query,
                desc=f"Adding {charge_state} cathodes properties",
                disable=mute_progress_bars,
            ):
                try:
                    # Locate the row in the database corresponding to the material ID
                    i_row = (
                        self.databases[stage]
                        .index[self.databases[stage][f"id_{charge_state}"] == material.material_id]
                        .tolist()[0]
                    )

                    # Assign material properties to the database
                    for property_name in MAT_PROPERTIES.keys():
                        self.databases[stage].at[i_row, f"{charge_state}_{property_name}"] = (
                            getattr(material, property_name, None)
                        )
                except IndexError:
                    logger.error(f"Material ID {material.material_id} not found in the database.")
                    raise MissingData(
                        f"Material ID {material.material_id} not found in the database."
                    )
                except Exception as e:
                    logger.error(
                        f"Failed to add properties for Material ID {material.material_id}: {e}"
                    )
                    raise e

        logger.info(f"Material properties for '{charge_state}' cathodes added successfully.")

    def save_cif_files(
        self,
        stage: str,
        materials_mp_query: list,
        charge_state: str,
        mute_progress_bars: bool = True,
    ) -> None:
        """
        Save CIF files for materials and update the database accordingly.

        Manages the saving of CIF files for each material and updates the database with
        the file paths and relevant properties. Ensures that raw data remains immutable.

        Args:
            stage (str): Processing stage ('raw', 'processed', 'final').
            materials_mp_query (List[Any]): List of material query results.
            charge_state (str): The charge state of the cathode ('charge' or 'discharge').
            mute_progress_bars (bool, optional): Disable progress bar if True. Defaults to True.

        Raises:
            ImmutableRawDataError: If attempting to modify immutable raw data.
        """

        saving_dir = self.database_directories[stage] / charge_state

        if stage == "raw" and not self._update_raw:
            logger.error("Raw data must be treated as immutable!")
            logger.error(
                "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
            )
            raise ImmutableRawDataError(
                "Raw data must be treated as immutable!\n"
                "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
            )
        elif stage == "raw" and saving_dir.exists():
            logger.info(
                "Be careful you are changing the raw data which must be treated as immutable!"
            )

        # Clean the saving directory if it exists
        if saving_dir.exists():
            logger.warning(f"Cleaning the content in {saving_dir}")
            sh.rmtree(saving_dir)

        # Create the saving directory
        saving_dir.mkdir(parents=True, exist_ok=False)
        self.databases[stage][f"{charge_state}_path"] = pd.Series(dtype=str)

        # Save CIF files and update database paths
        for material in tqdm(
            materials_mp_query,
            desc=f"Saving {charge_state} cathodes",
            disable=mute_progress_bars,
        ):
            try:
                # Locate the row in the database corresponding to the material ID
                i_row = (
                    self.databases[stage]
                    .index[self.databases[stage][f"id_{charge_state}"] == material.material_id]
                    .tolist()[0]
                )

                # Define the CIF file path
                cif_path = saving_dir / f"{material.material_id}.cif"

                # Save the CIF file
                material.structure.to(filename=str(cif_path))

                # Update the database with the CIF file path
                self.databases[stage].at[i_row, f"{charge_state}_path"] = str(cif_path)

            except IndexError:
                logger.error(f"Material ID {material.material_id} not found in the database.")
                raise MissingData(f"Material ID {material.material_id} not found in the database.")
            except Exception as e:
                logger.error(f"Failed to save CIF for Material ID {material.material_id}: {e}")
                raise OSError(
                    f"Failed to save CIF for Material ID {material.material_id}: {e}"
                ) from e

        # Save the updated database
        self.save_database(stage)
        logger.info(f"CIF files for stage '{stage}' saved and database updated successfully.")

    def copy_cif_files(
        self,
        stage: str,
        charge_state: str,
        mute_progress_bars: bool = True,
    ) -> None:
        """
        Copy CIF files from the raw stage to another processing stage.

        Copies CIF files corresponding to the specified cathode state from the 'raw'
        processing stage to the target stage. Updates the database with the new file paths.

        Args:
            stage (str): Target processing stage ('processed', 'final').
            charge_state (str): The charge state of the cathode ('charge' or 'discharge').
            mute_progress_bars (bool, optional): Disable progress bar if True. Defaults to True.

        Raises:
            ValueError: If the target stage is 'raw'.
            MissingData: If the source CIF directory does not exist or is empty.
        """
        if stage == "raw":
            logger.error("Stage argument cannot be 'raw'.")
            logger.error("You can only copy from 'raw' to other stages, not to 'raw' itself.")
            raise ValueError("Stage argument cannot be 'raw'.")

        source_dir = self.database_directories["raw"] / charge_state
        saving_dir = self.database_directories[stage] / charge_state

        # Clean the saving directory if it exists
        if saving_dir.exists():
            logger.warning(f"Cleaning the content in {saving_dir}")
            sh.rmtree(saving_dir)

        # Check if source CIF directory exists and is not empty
        if not source_dir.exists() or not any(source_dir.iterdir()):
            logger.warning(
                f"The raw CIF directory does not exist or is empty. Check: {source_dir}"
            )
            raise MissingData(
                f"The raw CIF directory does not exist or is empty. Check: {source_dir}"
            )

        # Create the saving directory
        saving_dir.mkdir(parents=True, exist_ok=False)
        self.databases[stage][f"{charge_state}_path"] = pd.Series(dtype=str)

        # Copy CIF files and update database paths
        for material_id in tqdm(
            self.databases[stage][f"id_{charge_state}"],
            desc=f"Copying {charge_state} cathodes ('raw' -> '{stage}')",
            disable=mute_progress_bars,
        ):
            try:
                # Locate the row in the database corresponding to the material ID
                i_row = (
                    self.databases[stage]
                    .index[self.databases[stage][f"id_{charge_state}"] == material_id]
                    .tolist()[0]
                )

                # Define source and destination CIF file paths
                source_cif_path = source_dir / f"{material_id}.cif"
                cif_path = saving_dir / f"{material_id}.cif"

                # Copy the CIF file
                sh.copyfile(source_cif_path, cif_path)

                # Update the database with the new CIF file path
                self.databases[stage].at[i_row, f"{charge_state}_path"] = str(cif_path)

            except IndexError:
                logger.error(f"Material ID {material_id} not found in the database.")
                raise MissingData(f"Material ID {material_id} not found in the database.")
            except Exception as e:
                logger.error(f"Failed to copy CIF for Material ID {material_id}: {e}")
                raise OSError(f"Failed to copy CIF for Material ID {material_id}: {e}") from e

        # Save the updated database
        self.save_database(stage)
        logger.info(f"CIF files copied to stage '{stage}' and database updated successfully.")

    def load_interim(self, subset: str = "training") -> pd.DataFrame:
        """
        Load the existing interim databases.

        Checks for the presence of an existing database file for the given subset
        and loads it into a pandas DataFrame. If the database file does not exist,
        logs a warning and returns an empty DataFrame.

        Args:
            subset (str): The interim subset ('training', 'validation', 'testing').

        Returns:
            pd.DataFrame: The loaded database or an empty DataFrame if not found.
        """
        if subset not in self.interim_sets:
            logger.error(f"Invalid set: {subset}. Must be one of {self.interim_sets}.")
            raise ValueError(f"set must be one of {self.interim_sets}.")

        db_name = subset + "_db.json"
        db_path = INTERIM_DATA_DIR / "cathodes" / db_name
        if db_path.exists():
            self.subset[subset] = pd.read_json(db_path)
            logger.debug(f"Loaded existing database from {db_path}")
        else:
            logger.warning(f"No existing database found at {db_path}")
        return self.subset[subset]

    def __repr__(self) -> str:
        """
        Text representation of the CathodeDatabase instance.
        Used for `print()` and `str()` calls.

        Returns:
            str: ASCII table representation of the database
        """
        # Gather information about each stage
        data = {
            "Stage": [],
            "Entries": [],
            "Last Modified": [],
            "Size": [],
            "Storage Path": [],
        }

        # Calculate column widths
        widths = [10, 8, 17, 10, 55]

        for stage in self.processing_stages:
            # Get database info
            db = self.databases[stage]
            path = self.database_paths[stage]

            # Get file modification time and size if file exists
            if path.exists():
                modified = path.stat().st_mtime
                modified_time = pd.Timestamp.fromtimestamp(modified).strftime("%Y-%m-%d %H:%M")
                size = path.stat().st_size / 1024  # Convert to KB
                size_str = f"{size:.1f} KB" if size < 1024 else f"{size / 1024:.1f} MB"
            else:
                modified_time = "Not created"
                size_str = "0 KB"

            path_str = str(path.resolve())
            if len(path_str) > widths[4]:
                path_str = ".." + path_str[len(path_str) - widths[4] + 3 :]

            # Append data
            data["Stage"].append(stage.capitalize())
            data["Entries"].append(len(db))
            data["Last Modified"].append(modified_time)
            data["Size"].append(size_str)
            data["Storage Path"].append(path_str)

        # Create DataFrame
        info_df = pd.DataFrame(data)

        # Text representation for terminal/print
        def create_separator(widths):
            return "+" + "+".join("-" * (w + 1) for w in widths) + "+"

        # Create the text representation
        lines = []

        # Add title
        title = f" {self.__class__.__name__} Summary "
        lines.append(f"\n{title:=^{sum(widths) + len(widths) * 2 + 1}}")

        # Add header
        separator = create_separator(widths)
        lines.append(separator)

        header = (
            "|" + "|".join(f" {col:<{widths[i]}}" for i, col in enumerate(info_df.columns)) + "|"
        )
        lines.append(header)
        lines.append(separator)

        # Add data rows
        for _, row in info_df.iterrows():
            line = "|" + "|".join(f" {str(val):<{widths[i]}}" for i, val in enumerate(row)) + "|"
            lines.append(line)

        # Add bottom separator
        lines.append(separator)

        # Add additional info
        lines.append(f"\nWorking Ion: {self.working_ion}")
        lines.append(f"Battery Type: {self.battery_type}")

        return "\n".join(lines)

    def _repr_html_(self) -> str:
        """
        HTML representation of the CathodeDatabase instance.
        Used for Jupyter notebook display.

        Returns:
            str: HTML representation of the database
        """
        # Gather information about each stage
        data = {
            "Stage": [],
            "Entries": [],
            "Last Modified": [],
            "Size": [],
            "Storage Path": [],
        }

        for stage in self.processing_stages:
            # Get database info
            db = self.databases[stage]
            path = self.database_paths[stage]

            # Get file modification time and size if file exists
            if path.exists():
                modified = path.stat().st_mtime
                modified_time = pd.Timestamp.fromtimestamp(modified).strftime("%Y-%m-%d %H:%M")
                size = path.stat().st_size / 1024  # Convert to KB
                size_str = f"{size:.1f} KB" if size < 1024 else f"{size / 1024:.1f} MB"
            else:
                modified_time = "Not created"
                size_str = "0 KB"

            # Append data
            data["Stage"].append(stage.capitalize())
            data["Entries"].append(len(db))
            data["Last Modified"].append(modified_time)
            data["Size"].append(size_str)
            data["Storage Path"].append(str(path.resolve()))

        # Create DataFrame
        info_df = pd.DataFrame(data)

        # Generate header row
        header_cells = " ".join(
            f'<th style="padding: 12px 15px; text-align: left;">{col}</th>'
            for col in info_df.columns
        )

        # Generate table rows
        table_rows = ""
        for _, row in info_df.iterrows():
            cells = "".join(f'<td style="padding: 12px 15px;">{val}</td>' for val in row)
            table_rows += f"<tr style='border-bottom: 1px solid #e9ecef;'>{cells}</tr>"

        # Create the complete HTML
        html = (
            """<style>
                @media (prefers-color-scheme: dark) {
                    .database-container { background-color: #1e1e1e !important; }
                    .database-title { color: #e0e0e0 !important; }
                    .database-table { background-color: #2d2d2d !important; }
                    .database-header { background-color: #4a4a4a !important; }
                    .database-cell { border-color: #404040 !important; }
                    .database-info { color: #b0b0b0 !important; }
                }
            </style>"""
            '<div style="font-family: Arial, sans-serif; padding: 20px; background:transparent; '
            'border-radius: 8px;">'
            f'<h3 style="color: #58bac7; margin-bottom: 15px;">{self.__class__.__name__}</h3>'
            '<div style="overflow-x: auto;">'
            '<table class="database-table" style="border-collapse: collapse; width: 100%;'
            ' box-shadow: 0 1px 3px rgba(0,0,0,0.1); background:transparent;">'
            # '<table style="border-collapse: collapse; width: 100%; background-color: white; '
            # 'box-shadow: 0 1px 3px rgba(0,0,0,0.1);">'
            "<thead>"
            f'<tr style="background-color: #58bac7; color: white;">{header_cells}</tr>'
            "</thead>"
            f"<tbody>{table_rows}</tbody>"
            "</table>"
            "</div>"
            '<div style="margin-top: 10px; color: #666; font-size: 1.1em;">'
            f"Working Ion: {self.working_ion}<br>"
            f"Battery Type: {self.battery_type}"
            "</div>"
            "</div>"
        )
        return html

__init__(data_dir=DATA_DIR, name='cathodes', working_ion='Li', battery_type='insertion')

Sets up the directory structure for storing data across different processing stages (raw/, processed/, final/) and initializes placeholders for database paths and data.

Parameters:

Name Type Description Default
data_dir Path

Root directory path for storing data. Defaults to DATA_DIR from config.

DATA_DIR
working_ion str

The working ion used in the dataset (e.g., 'Li'). Defaults to "Li".

'Li'
battery_type str

The battery type (e.g., 'insertion', 'conversion'). Defaults to "insertion".

'insertion'

Raises:

Type Description
NotImplementedError

If battery_type is 'conversion', which is not yet supported by the Materials Project.

ValueError

If battery_type is neither 'insertion' nor 'conversion'.

Source code in energy_gnome/dataset/cathodes.py
Python
def __init__(
    self,
    data_dir: Path = DATA_DIR,
    name: str = "cathodes",
    working_ion: str = "Li",
    battery_type: str = "insertion",
):
    """
    Initialize the CathodeDatabase with a root data directory and processing stage.

    Sets up the directory structure for storing data across different processing stages
    (`raw/`, `processed/`, `final/`) and initializes placeholders for database paths and data.

    Args:
        data_dir (Path, optional): Root directory path for storing data.
                                   Defaults to DATA_DIR from config.
        working_ion (str, optional): The working ion used in the dataset (e.g., 'Li').
                                     Defaults to "Li".
        battery_type (str, optional): The battery type (e.g., 'insertion', 'conversion').
                                      Defaults to "insertion".

    Raises:
        NotImplementedError: If `battery_type` is 'conversion', which is not yet supported.
        ValueError: If `battery_type` is neither 'insertion' nor 'conversion'.
    """
    super().__init__(name=name, data_dir=data_dir)
    self.working_ion = working_ion

    if battery_type == "insertion":
        self.battery_type = battery_type
    elif battery_type == "conversion":
        logger.error("`conversion` battery type is not yet implemented in Material Project.")
        raise NotImplementedError(
            "`conversion` battery type is not yet present in Material Project."
        )
    else:
        logger.error(
            f"Invalid battery type: {battery_type}. Must be 'insertion' or 'conversion'."
        )
        raise ValueError(
            "`battery_type` can be only `insertion` or `conversion` (not yet implemented)"
        )

    # Initialize directories, paths, and databases for each stage
    self.database_directories = {
        stage: self.data_dir / stage / "cathodes" / battery_type / working_ion
        for stage in self.processing_stages
    }
    for stage_dir in self.database_directories.values():
        stage_dir.mkdir(parents=True, exist_ok=True)

    self.database_paths = {
        stage: dir_path / "database.json"
        for stage, dir_path in self.database_directories.items()
    }

    self.databases = {stage: pd.DataFrame() for stage in self.processing_stages}
    self._battery_models = pd.DataFrame()
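
A minimal usage sketch, assuming `CathodeDatabase` is exported from `energy_gnome.dataset` as this module page suggests: build the database for Li-ion insertion cathodes and print its stage summary.

Python
from energy_gnome.dataset import CathodeDatabase  # assumed import path

# Build a Li-ion insertion cathode database rooted in the default DATA_DIR;
# the raw/, processed/ and final/ directories are created if missing.
cathodes = CathodeDatabase(name="cathodes", working_ion="Li", battery_type="insertion")

# __repr__ renders the ASCII stage summary shown above.
print(cathodes)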

__repr__()

Text representation of the CathodeDatabase instance. Used for print() and str() calls.

Returns:

Name Type Description
str str

ASCII table representation of the database

Source code in energy_gnome/dataset/cathodes.py
Python
def __repr__(self) -> str:
    """
    Text representation of the CathodeDatabase instance.
    Used for `print()` and `str()` calls.

    Returns:
        str: ASCII table representation of the database
    """
    # Gather information about each stage
    data = {
        "Stage": [],
        "Entries": [],
        "Last Modified": [],
        "Size": [],
        "Storage Path": [],
    }

    # Calculate column widths
    widths = [10, 8, 17, 10, 55]

    for stage in self.processing_stages:
        # Get database info
        db = self.databases[stage]
        path = self.database_paths[stage]

        # Get file modification time and size if file exists
        if path.exists():
            modified = path.stat().st_mtime
            modified_time = pd.Timestamp.fromtimestamp(modified).strftime("%Y-%m-%d %H:%M")
            size = path.stat().st_size / 1024  # Convert to KB
            size_str = f"{size:.1f} KB" if size < 1024 else f"{size / 1024:.1f} MB"
        else:
            modified_time = "Not created"
            size_str = "0 KB"

        path_str = str(path.resolve())
        if len(path_str) > widths[4]:
            path_str = ".." + path_str[len(path_str) - widths[4] + 3 :]

        # Append data
        data["Stage"].append(stage.capitalize())
        data["Entries"].append(len(db))
        data["Last Modified"].append(modified_time)
        data["Size"].append(size_str)
        data["Storage Path"].append(path_str)

    # Create DataFrame
    info_df = pd.DataFrame(data)

    # Text representation for terminal/print
    def create_separator(widths):
        return "+" + "+".join("-" * (w + 1) for w in widths) + "+"

    # Create the text representation
    lines = []

    # Add title
    title = f" {self.__class__.__name__} Summary "
    lines.append(f"\n{title:=^{sum(widths) + len(widths) * 2 + 1}}")

    # Add header
    separator = create_separator(widths)
    lines.append(separator)

    header = (
        "|" + "|".join(f" {col:<{widths[i]}}" for i, col in enumerate(info_df.columns)) + "|"
    )
    lines.append(header)
    lines.append(separator)

    # Add data rows
    for _, row in info_df.iterrows():
        line = "|" + "|".join(f" {str(val):<{widths[i]}}" for i, val in enumerate(row)) + "|"
        lines.append(line)

    # Add bottom separator
    lines.append(separator)

    # Add additional info
    lines.append(f"\nWorking Ion: {self.working_ion}")
    lines.append(f"Battery Type: {self.battery_type}")

    return "\n".join(lines)

add_material_properties(stage, materials_mp_query, charge_state, mute_progress_bars=True)

Add material properties to the database from Materials Project query results.

Matches each query result to its row in the database and fills in the corresponding material property columns.

Parameters:

Name Type Description Default
stage str

Processing stage ('raw', 'processed', 'final').

required
materials_mp_query List[Any]

List of material query results.

required
charge_state str

The state of the cathode ('charge' or 'discharge').

required
mute_progress_bars bool

Disable progress bar if True. Defaults to True.

True

Returns:

Type Description
DataFrame

pd.DataFrame: Updated database with material properties.

Raises:

Type Description
ImmutableRawDataError

If attempting to modify immutable raw data.

MissingData

If a material ID is not found in the database.

Source code in energy_gnome/dataset/cathodes.py
Python
def add_material_properties(
    self,
    stage: str,
    materials_mp_query: list,
    charge_state: str,
    mute_progress_bars: bool = True,
) -> pd.DataFrame:
    """
    Add material properties to the database from Materials Project query results.

    Matches each query result to its row in the database and fills in the corresponding material property columns.

    Args:
        stage (str): Processing stage ('raw', 'processed', 'final').
        materials_mp_query (List[Any]): List of material query results.
        charge_state (str): The state of the cathode ('charge' or 'discharge').
        mute_progress_bars (bool, optional): Disable progress bar if True. Defaults to True.

    Returns:
        pd.DataFrame: Updated database with material properties.

    Raises:
        ImmutableRawDataError: If attempting to modify immutable raw data.
        MissingData: If a material ID is not found in the database.
    """
    if stage == "raw" and not self._update_raw:
        logger.error("Raw data must be treated as immutable!")
        logger.error(
            "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
        )
        raise ImmutableRawDataError(
            "Raw data must be treated as immutable!\n"
            "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
        )
    else:
        if stage == "raw":
            logger.info(
                "Be careful you are changing the raw data which must be treated as immutable!"
            )
        logger.debug(
            f"Adding material properties to {stage} data for cathode state: {charge_state}"
        )

        # Ensure necessary columns are present
        self._add_materials_properties_columns(stage, self.databases[stage], charge_state)

        for material in tqdm(
            materials_mp_query,
            desc=f"Adding {charge_state} cathodes properties",
            disable=mute_progress_bars,
        ):
            try:
                # Locate the row in the database corresponding to the material ID
                i_row = (
                    self.databases[stage]
                    .index[self.databases[stage][f"id_{charge_state}"] == material.material_id]
                    .tolist()[0]
                )

                # Assign material properties to the database
                for property_name in MAT_PROPERTIES.keys():
                    self.databases[stage].at[i_row, f"{charge_state}_{property_name}"] = (
                        getattr(material, property_name, None)
                    )
            except IndexError:
                logger.error(f"Material ID {material.material_id} not found in the database.")
                raise MissingData(
                    f"Material ID {material.material_id} not found in the database."
                )
            except Exception as e:
                logger.error(
                    f"Failed to add properties for Material ID {material.material_id}: {e}"
                )
                raise e

    logger.info(f"Material properties for '{charge_state}' cathodes added successfully.")

backup_and_changelog(old_db, new_db, differences, stage)

Backup the old database and update the changelog with identified differences.

Creates a backup of the existing database and appends a changelog entry detailing the differences between the old and new databases. The changelog includes information such as entry identifiers, formulas, and last updated timestamps.

Parameters:

Name Type Description Default
old_db DataFrame

The existing database before updates.

required
new_db DataFrame

The new database containing updates.

required
differences Series

Series of identifiers that are new or updated.

required
stage str

The processing stage ('raw', 'processed', 'final').

required
Source code in energy_gnome/dataset/cathodes.py
Python
def backup_and_changelog(
    self,
    old_db: pd.DataFrame,
    new_db: pd.DataFrame,
    differences: pd.Series,
    stage: str,
) -> None:
    """
    Backup the old database and update the changelog with identified differences.

    Creates a backup of the existing database and appends a changelog entry detailing
    the differences between the old and new databases. The changelog includes
    information such as entry identifiers, formulas, and last updated timestamps.

    Args:
        old_db (pd.DataFrame): The existing database before updates.
        new_db (pd.DataFrame): The new database containing updates.
        differences (pd.Series): Series of identifiers that are new or updated.
        stage (str): The processing stage ('raw', 'processed', 'final').
    """
    if stage not in self.processing_stages:
        logger.error(f"Invalid stage: {stage}. Must be one of {self.processing_stages}.")
        raise ValueError(f"stage must be one of {self.processing_stages}.")

    backup_path = self.database_directories[stage] / "old_database.json"
    try:
        old_db.to_json(backup_path)
        logger.debug(f"Old database backed up to {backup_path}")
    except Exception as e:
        logger.error(f"Failed to backup old database to {backup_path}: {e}")
        raise OSError(f"Failed to backup old database to {backup_path}: {e}") from e

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    changelog_path = self.database_directories[stage] / "changelog.txt"
    changelog_entries = [
        f"= Change Log - {timestamp} ".ljust(70, "=") + "\n",
        "Difference old_database.json VS database.json\n",
        f"{'ID':<15}{'Formula':<30}{'Last Updated (MP)':<25}\n",
        "-" * 70 + "\n",
    ]
    # Tailored to the parent class: changelog entries use battery-specific identifiers
    for identifier in differences["battery_id"]:
        row = new_db.loc[new_db["battery_id"] == identifier]
        if not row.empty:
            formula = row["battery_formula"].values[0]
            last_updated = row["last_updated"].values[0]
        else:
            formula = "N/A"
            last_updated = "N/A"
        changelog_entries.append(f"{identifier:<15}{formula:<30}{last_updated:<20}\n")

    try:
        with open(changelog_path, "a") as file:
            file.writelines(changelog_entries)
        logger.debug(f"Changelog updated at {changelog_path} with {len(differences)} changes.")
    except Exception as e:
        logger.error(f"Failed to update changelog at {changelog_path}: {e}")
        raise OSError(f"Failed to update changelog at {changelog_path}: {e}") from e

compare_and_update(new_db, stage)

Compare and update the database with new entries.

Identifies new entries and updates the database accordingly. Ensures that raw data remains immutable by preventing updates unless explicitly allowed.

Parameters:

Name Type Description Default
new_db DataFrame

New database to compare.

required
stage str

Processing stage ("raw", "processed", "final").

required

Returns:

Type Description
DataFrame

pd.DataFrame: Updated database containing new entries.

Raises:

Type Description
ImmutableRawDataError

If attempting to modify immutable raw data.

Source code in energy_gnome/dataset/cathodes.py
Python
def compare_and_update(self, new_db: pd.DataFrame, stage: str) -> pd.DataFrame:
    """
    Compare and update the database with new entries.

    Identifies new entries and updates the database accordingly. Ensures that raw data
    remains immutable by preventing updates unless explicitly allowed.

    Args:
        new_db (pd.DataFrame): New database to compare.
        stage (str): Processing stage ("raw", "processed", "final").

    Returns:
        pd.DataFrame: Updated database containing new entries.

    Raises:
        ImmutableRawDataError: If attempting to modify immutable raw data.
    """
    old_db = self.get_database(stage=stage)
    db_diff = self.compare_databases(new_db, stage)
    if not db_diff.empty:
        logger.warning(f"The new database contains {len(db_diff)} new items.")

        if stage == "raw" and not self._update_raw:
            logger.error("Raw data must be treated as immutable!")
            logger.error(
                "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
            )
            raise ImmutableRawDataError(
                "Raw data must be treated as immutable!\n"
                "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
            )
        else:
            if stage == "raw":
                logger.info(
                    "Be careful you are changing the raw data which must be treated as immutable!"
                )
            logger.info(
                f"Updating the {stage} data and saving it in {self.database_paths[stage]}."
            )
            self.backup_and_changelog(
                old_db,
                new_db,
                db_diff,
                stage,
            )
            self.databases[stage] = new_db
            self.save_database(stage)
    else:
        logger.info("No new items found. No update required.")

compare_databases(new_db, stage)

Compare two databases and identify new entry IDs.

Parameters:

Name Type Description Default
new_db DataFrame

New database to compare.

required
stage str

Processing stage ("raw", "processed", "final").

required

Returns:

Type Description
DataFrame

pd.DataFrame: Subset of new_db containing only new entry IDs.

Source code in energy_gnome/dataset/cathodes.py
Python
def compare_databases(self, new_db: pd.DataFrame, stage: str) -> pd.DataFrame:
    """
    Compare two databases and identify new entry IDs.

    Args:
        new_db (pd.DataFrame): New database to compare.
        stage (str): Processing stage ("raw", "processed", "final").

    Returns:
        pd.DataFrame: Subset of `new_db` containing only new entry IDs.
    """
    old_db = self.get_database(stage=stage)
    if not old_db.empty:
        new_ids_set = set(new_db["battery_id"])
        old_ids_set = set(old_db["battery_id"])
        new_ids_only = new_ids_set - old_ids_set
        logger.debug(f"Found {len(new_ids_only)} new battery IDs in the new database.")
        return new_db[new_db["battery_id"].isin(new_ids_only)]
    else:
        logger.warning("Nothing to compare here...")
        return new_db

copy_cif_files(stage, charge_state, mute_progress_bars=True)

Copy CIF files from the raw stage to another processing stage.

Copies CIF files corresponding to the specified cathode state from the 'raw' processing stage to the target stage. Updates the database with the new file paths.

Parameters:

Name Type Description Default
stage str

Target processing stage ('processed', 'final').

required
charge_state str

The charge state of the cathode ('charge' or 'discharge').

required
mute_progress_bars bool

Disable progress bar if True. Defaults to True.

True

Raises:

Type Description
ValueError

If the target stage is 'raw'.

MissingData

If the source CIF directory does not exist or is empty.

Source code in energy_gnome/dataset/cathodes.py
Python
def copy_cif_files(
    self,
    stage: str,
    charge_state: str,
    mute_progress_bars: bool = True,
) -> None:
    """
    Copy CIF files from the raw stage to another processing stage.

    Copies CIF files corresponding to the specified cathode state from the 'raw'
    processing stage to the target stage. Updates the database with the new file paths.

    Args:
        stage (str): Target processing stage ('processed', 'final').
        charge_state (str): The charge state of the cathode ('charge' or 'discharge').
        mute_progress_bars (bool, optional): Disable progress bar if True. Defaults to True.

    Raises:
        ValueError: If the target stage is 'raw'.
        MissingData: If the source CIF directory does not exist or is empty.
    """
    if stage == "raw":
        logger.error("Stage argument cannot be 'raw'.")
        logger.error("You can only copy from 'raw' to other stages, not to 'raw' itself.")
        raise ValueError("Stage argument cannot be 'raw'.")

    source_dir = self.database_directories["raw"] / charge_state
    saving_dir = self.database_directories[stage] / charge_state

    # Clean the saving directory if it exists
    if saving_dir.exists():
        logger.warning(f"Cleaning the content in {saving_dir}")
        sh.rmtree(saving_dir)

    # Check if source CIF directory exists and is not empty
    if not source_dir.exists() or not any(source_dir.iterdir()):
        logger.warning(
            f"The raw CIF directory does not exist or is empty. Check: {source_dir}"
        )
        raise MissingData(
            f"The raw CIF directory does not exist or is empty. Check: {source_dir}"
        )

    # Create the saving directory
    saving_dir.mkdir(parents=True, exist_ok=False)
    self.databases[stage][f"{charge_state}_path"] = pd.Series(dtype=str)

    # Copy CIF files and update database paths
    for material_id in tqdm(
        self.databases[stage][f"id_{charge_state}"],
        desc=f"Copying {charge_state} cathodes ('raw' -> '{stage}')",
        disable=mute_progress_bars,
    ):
        try:
            # Locate the row in the database corresponding to the material ID
            i_row = (
                self.databases[stage]
                .index[self.databases[stage][f"id_{charge_state}"] == material_id]
                .tolist()[0]
            )

            # Define source and destination CIF file paths
            source_cif_path = source_dir / f"{material_id}.cif"
            cif_path = saving_dir / f"{material_id}.cif"

            # Copy the CIF file
            sh.copyfile(source_cif_path, cif_path)

            # Update the database with the new CIF file path
            self.databases[stage].at[i_row, f"{charge_state}_path"] = str(cif_path)

        except IndexError:
            logger.error(f"Material ID {material_id} not found in the database.")
            raise MissingData(f"Material ID {material_id} not found in the database.")
        except Exception as e:
            logger.error(f"Failed to copy CIF for Material ID {material_id}: {e}")
            raise OSError(f"Failed to copy CIF for Material ID {material_id}: {e}") from e

    # Save the updated database
    self.save_database(stage)
    logger.info(f"CIF files copied to stage '{stage}' and database updated successfully.")

load_interim(subset='training')

Load the existing interim databases.

Checks for the presence of an existing database file for the given subset and loads it into a pandas DataFrame. If the database file does not exist, logs a warning and returns an empty DataFrame.

Parameters:

Name Type Description Default
subset str

The interim subset ('training', 'validation', 'testing').

required

Returns:

Type Description
DataFrame

pd.DataFrame: The loaded database or an empty DataFrame if not found.

Source code in energy_gnome/dataset/cathodes.py
Python
def load_interim(self, subset: str = "training") -> pd.DataFrame:
    """
    Load the existing interim databases.

    Checks for the presence of an existing database file for the given subset
    and loads it into a pandas DataFrame. If the database file does not exist,
    logs a warning and returns an empty DataFrame.

    Args:
        subset (str): The interim subset ('training', 'validation', 'testing').

    Returns:
        pd.DataFrame: The loaded database or an empty DataFrame if not found.
    """
    if subset not in self.interim_sets:
        logger.error(f"Invalid set: {subset}. Must be one of {self.interim_sets}.")
        raise ValueError(f"set must be one of {self.interim_sets}.")

    db_name = subset + "_db.json"
    db_path = INTERIM_DATA_DIR / "cathodes" / db_name
    if db_path.exists():
        self.subset[subset] = pd.read_json(db_path)
        logger.debug(f"Loaded existing database from {db_path}")
    else:
        logger.warning(f"No existing database found at {db_path}")
    return self.subset[subset]
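
A short sketch, assuming interim splits have been written under `INTERIM_DATA_DIR / "cathodes"` by an earlier processing step:

Python
# Load the interim training split; an empty DataFrame is returned (and a
# warning logged) if the split has not been created yet.
train_df = cathodes.load_interim(subset="training")
print(f"Training split contains {len(train_df)} entries")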

retrieve_materials(stage, charge_state, mute_progress_bars=True)

Retrieve material structures from the Material Project API.

Fetches material structures based on the processing stage and charge state.

Parameters:

Name Type Description Default
stage str

Processing stage ('raw', 'processed', 'final').

required
charge_state str

Cathode charge state ('charge', 'discharge').

required
mute_progress_bars bool

Disable progress bar if True. Defaults to True.

True

Returns:

Type Description
list[Any]

List[Any]: List of retrieved material objects.

Raises:

Type Description
ValueError

If the charge_state is invalid.

MissingData

If the required data is missing in the database.

Source code in energy_gnome/dataset/cathodes.py
Python
def retrieve_materials(
    self, stage: str, charge_state: str, mute_progress_bars: bool = True
) -> list[Any]:
    """
    Retrieve material structures from the Material Project API.

    Fetches material structures based on the processing stage and charge state.

    Args:
        stage (str): Processing stage ('raw', 'processed', 'final').
        charge_state (str): Cathode charge state ('charge', 'discharge').
        mute_progress_bars (bool, optional): Disable progress bar if True. Defaults to True.

    Returns:
        List[Any]: List of retrieved material objects.

    Raises:
        ValueError: If the charge_state is invalid.
        MissingData: If the required data is missing in the database.
    """
    if charge_state not in ["charge", "discharge"]:
        logger.error(f"Invalid charge_state: {charge_state}. Must be 'charge' or 'discharge'.")
        raise ValueError("charge_state must be 'charge' or 'discharge'.")

    material_ids = self.databases[stage][f"id_{charge_state}"].tolist()
    if not material_ids:
        logger.warning(
            f"No material IDs found for stage '{stage}' and charge_state '{charge_state}'."
        )
        raise MissingData(
            f"No material IDs found for stage '{stage}' and charge_state '{charge_state}'."
        )

    logger.debug(
        f"Retrieving materials for stage '{stage}' and charge_state '{charge_state}'."
    )
    query = get_material_by_id(
        material_ids,
        mute_progress_bars=mute_progress_bars,
    )
    return query
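
For instance (illustrative; assumes the raw database already contains the `id_charge` column filled by `retrieve_models`):

Python
# Fetch the charged-state structures listed in the raw database.
charged = cathodes.retrieve_materials(
    stage="raw",
    charge_state="charge",
    mute_progress_bars=False,
)
print(f"Retrieved {len(charged)} charged-state structures")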

retrieve_models(mute_progress_bars=True)

Retrieve battery models from the Materials Project API.

Connects to the Material Project API using MPRester, queries for materials based on the working ion and processing stage, and retrieves the specified fields. Cleans the data by removing entries with missing critical identifiers.

Parameters:

Name Type Description Default
mute_progress_bars bool

If True, mutes the Material Project API progress bars. Defaults to True.

True

Returns:

Type Description
DataFrame

pd.DataFrame: DataFrame containing the retrieved and cleaned models.

Raises:

Type Description
Exception

If the API query fails.

Source code in energy_gnome/dataset/cathodes.py
Python
def retrieve_models(self, mute_progress_bars: bool = True) -> pd.DataFrame:
    """
    Retrieve battery models from the Materials Project API.

    Connects to the Material Project API using MPRester, queries for materials
    based on the working ion and processing stage, and retrieves the specified fields.
    Cleans the data by removing entries with missing critical identifiers.

    Args:
        mute_progress_bars (bool, optional):
            If `True`, mutes the Material Project API progress bars.
            Defaults to `True`.

    Returns:
        pd.DataFrame: DataFrame containing the retrieved and cleaned models.

    Raises:
        Exception: If the API query fails.
    """
    mp_api_key = get_mp_api_key()
    logger.debug("MP querying for insertion battery models.")

    with MPRester(mp_api_key, mute_progress_bars=mute_progress_bars) as mpr:
        try:
            query = mpr.materials.insertion_electrodes.search(
                working_ion=self.working_ion, fields=BAT_FIELDS
            )
            logger.info(
                f"MP query successful, {len(query)} {self.working_ion}-ion batteries found."
            )
        except Exception as e:
            raise e
    logger.debug("Converting MP query results into DataFrame.")
    battery_models_database = convert_my_query_to_dataframe(
        query, mute_progress_bars=mute_progress_bars
    )

    # Fast cleaning
    logger.debug("Removing NaN")
    battery_models_database = battery_models_database.dropna(
        axis=0, how="any", subset=["id_charge", "id_discharge"]
    )
    battery_models_database = battery_models_database.dropna(axis=1, how="all")
    self._battery_models = battery_models_database
    logger.success(f"{self.working_ion}-ion batteries model retrieved successfully.")
    return self._battery_models
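
A dry-run comparison sketch, assuming a valid Materials Project API key is available to `get_mp_api_key`:

Python
# Query the current Li-ion insertion electrode models and preview which
# battery IDs are new with respect to the stored raw snapshot, without
# writing anything to disk.
new_models = cathodes.retrieve_models(mute_progress_bars=False)
diff = cathodes.compare_databases(new_models, stage="raw")
print(f"{len(diff)} new battery models since the last raw update")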

retrieve_remote(mute_progress_bars=True)

Retrieve models from the Material Project API.

Wrapper method to call retrieve_models.

Parameters:

Name Type Description Default
mute_progress_bars bool

If True, mutes the Material Project API progress bars. Defaults to True.

True

Returns:

Type Description
DataFrame

pd.DataFrame: DataFrame containing the retrieved models.

Source code in energy_gnome/dataset/cathodes.py
Python
def retrieve_remote(self, mute_progress_bars: bool = True) -> pd.DataFrame:
    """
    Retrieve models from the Material Project API.

    Wrapper method to call `retrieve_models`.

    Args:
        mute_progress_bars (bool, optional):
            If `True`, mutes the Material Project API progress bars.
            Defaults to `True`.

    Returns:
        pd.DataFrame: DataFrame containing the retrieved models.
    """
    return self.retrieve_models(mute_progress_bars=mute_progress_bars)

save_cif_files(stage, materials_mp_query, charge_state, mute_progress_bars=True)

Save CIF files for materials and update the database accordingly.

Manages the saving of CIF files for each material and updates the database with the file paths and relevant properties. Ensures that raw data remains immutable.

Parameters:

Name Type Description Default
stage str

Processing stage ('raw', 'processed', 'final').

required
materials_mp_query List[Any]

List of material query results.

required
charge_state str

The charge state of the cathode ('charge' or 'discharge').

required
mute_progress_bars bool

Disable progress bar if True. Defaults to True.

True

Raises:

Type Description
ImmutableRawDataError

If attempting to modify immutable raw data.

Source code in energy_gnome/dataset/cathodes.py
Python
def save_cif_files(
    self,
    stage: str,
    materials_mp_query: list,
    charge_state: str,
    mute_progress_bars: bool = True,
) -> None:
    """
    Save CIF files for materials and update the database accordingly.

    Manages the saving of CIF files for each material and updates the database with
    the file paths and relevant properties. Ensures that raw data remains immutable.

    Args:
        stage (str): Processing stage ('raw', 'processed', 'final').
        materials_mp_query (List[Any]): List of material query results.
        charge_state (str): The charge state of the cathode ('charge' or 'discharge').
        mute_progress_bars (bool, optional): Disable progress bar if True. Defaults to True.

    Raises:
        ImmutableRawDataError: If attempting to modify immutable raw data.
    """

    saving_dir = self.database_directories[stage] / charge_state

    if stage == "raw" and not self._update_raw:
        logger.error("Raw data must be treated as immutable!")
        logger.error(
            "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
        )
        raise ImmutableRawDataError(
            "Raw data must be treated as immutable!\n"
            "It's okay to read and copy raw data to manipulate it into new outputs, but never okay to change it in place."
        )
    elif stage == "raw" and saving_dir.exists():
        logger.info(
            "Be careful you are changing the raw data which must be treated as immutable!"
        )

    # Clean the saving directory if it exists
    if saving_dir.exists():
        logger.warning(f"Cleaning the content in {saving_dir}")
        sh.rmtree(saving_dir)

    # Create the saving directory
    saving_dir.mkdir(parents=True, exist_ok=False)
    self.databases[stage][f"{charge_state}_path"] = pd.Series(dtype=str)

    # Save CIF files and update database paths
    for material in tqdm(
        materials_mp_query,
        desc=f"Saving {charge_state} cathodes",
        disable=mute_progress_bars,
    ):
        try:
            # Locate the row in the database corresponding to the material ID
            i_row = (
                self.databases[stage]
                .index[self.databases[stage][f"id_{charge_state}"] == material.material_id]
                .tolist()[0]
            )

            # Define the CIF file path
            cif_path = saving_dir / f"{material.material_id}.cif"

            # Save the CIF file
            material.structure.to(filename=str(cif_path))

            # Update the database with the CIF file path
            self.databases[stage].at[i_row, f"{charge_state}_path"] = str(cif_path)

        except IndexError:
            logger.error(f"Material ID {material.material_id} not found in the database.")
            raise MissingData(f"Material ID {material.material_id} not found in the database.")
        except Exception as e:
            logger.error(f"Failed to save CIF for Material ID {material.material_id}: {e}")
            raise OSError(
                f"Failed to save CIF for Material ID {material.material_id}: {e}"
            ) from e

    # Save the updated database
    self.save_database(stage)
    logger.info(f"CIF files for stage '{stage}' saved and database updated successfully.")

energy_gnome.dataset.PerovskiteDatabase

Bases: BaseDatabase

Source code in energy_gnome/dataset/perovskites.py
Python
class PerovskiteDatabase(BaseDatabase):
    def __init__(
        self,
        name: str = "perovskites",
        data_dir: Path | str = DATA_DIR,
        external_perovproj_path: Path | str = EXTERNAL_DATA_DIR
        / Path("perovskites")
        / Path("perovproject_db.json"),
    ):
        """
        Initialize the PerovskiteDatabase with a root data directory and processing stage.

        This constructor sets up the directory structure for storing data across different
        processing stages (`raw`, `processed`, `final`). It also initializes placeholders
        for the database paths and data specific to the `PerovskiteDatabase` class, including
        a path for the external Perovskite project data.

        Args:
            name (str, optional): The name of the database. Defaults to "perovskites".
            data_dir (Path or str, optional): Root directory path for storing data. Defaults
                to `DATA_DIR` from the configuration.
            external_perovproj_path (Path or str, optional): Path to the external Perovskite
                Project database file. Defaults to `EXTERNAL_DATA_DIR / "perovskites" /
                "perovproject_db.json"`.

        Raises:
            NotImplementedError: If the specified processing stage is not supported.
            ImmutableRawDataError: If attempting to set an unsupported processing stage.
        """
        super().__init__(name=name, data_dir=data_dir)
        self._perovskites = pd.DataFrame()
        self.external_perovproj_path: Path | str = external_perovproj_path

    def _set_is_specialized(self):
        """
        Set the `is_specialized` attribute to `True`.

        This method marks the database as specialized by setting the `is_specialized`
        attribute to `True`. It is typically used to indicate that the database
        is intended for a specific class of data, corresponding to specialized
        energy materials.

        Returns:
            None
        """
        self.is_specialized = True

    def _pre_retrieve_robo(self, mute_progress_bars: bool = True) -> list[str]:
        """
        Retrieve Perovskite material IDs from the Robocrystallographer API.

        This method queries the Robocrystallographer tool through the Materials Project
        API to search for materials related to "Perovskite". It returns a list of material
        IDs corresponding to the Perovskite materials identified in the search results.

        Args:
            mute_progress_bars (bool, optional): Whether to mute progress bars for the
                API request. Defaults to `True`.

        Returns:
            (list[str]): A list of material IDs for Perovskite materials retrieved from
            Robocrystallographer.

        Raises:
            Exception: If there is an issue with the query to the Materials Project API.

        Logs:
            - INFO: If the query is successful, logging the number of Perovskite IDs found.
        """
        mp_api_key = get_mp_api_key()
        with MPRester(mp_api_key, mute_progress_bars=mute_progress_bars) as mpr:
            try:
                query = mpr.materials.robocrys.search(keywords=["Perovskite", "perovskite"])
                logger.info(
                    f"MP query successful, {len(query)} perovskite IDs found through Robocrystallographer."
                )
            except Exception as e:
                raise e
        ids_list_robo = [q.material_id for q in query]
        return ids_list_robo

    def _pre_retrieve_perovproj(self, mute_progress_bars: bool = True) -> list[str]:
        """
        Retrieve Perovskite material IDs from the Perovskite Project.

        This method queries the Materials Project API using the formulae from an external
        Perovskite Project file to search for matching material IDs. It returns a list of
        material IDs for Perovskite materials found through the Perovskite Project.

        Args:
            mute_progress_bars (bool, optional): Whether to mute progress bars for the
                API request. Defaults to `True`.

        Returns:
            (list[str]): A list of material IDs for Perovskite materials retrieved from
            the Perovskite Project.

        Raises:
            Exception: If there is an issue with the query to the Materials Project API.

        Logs:
            - INFO: If the query is successful, logging the number of Perovskite IDs found.
        """
        mp_api_key = get_mp_api_key()
        with open(self.external_perovproj_path) as f:
            dict_ = json.load(f)
        with MPRester(mp_api_key, mute_progress_bars=mute_progress_bars) as mpr:
            try:
                query = mpr.materials.summary.search(formula=dict_, fields="material_id")
                logger.info(
                    f"MP query successful, {len(query)} perovskite IDs found through Perovskite Project formulae."
                )
            except Exception as e:
                raise e
        ids_list_perovproj = [q.material_id for q in query]
        return ids_list_perovproj

    def retrieve_materials(self, mute_progress_bars: bool = True) -> tuple[pd.DataFrame, list]:
        """
        Retrieve Perovskite materials from the Materials Project API.

        This method connects to the Materials Project API using the `MPRester`, queries
        for materials related to Perovskites, and retrieves specified properties. The data
        is then cleaned by removing rows with missing critical fields and filtering out
        metallic perovskites. It returns the cleaned DataFrame together with the matching query results.

        Args:
            mute_progress_bars (bool, optional):
                If `True`, mutes the Material Project API progress bars during the request.
                Defaults to `True`.

        Returns:
            (tuple[pd.DataFrame, list]): The cleaned DataFrame of Perovskite materials together with
            the corresponding (filtered) Materials Project query results.

        Raises:
            Exception: If the API query fails or any issue occurs during data retrieval.

        Logs:
            - DEBUG: Logs the process of querying and cleaning data.
            - INFO: Logs successful query results and how many Perovskites were retrieved.
            - SUCCESS: Logs the successful retrieval of Perovskite materials.
        """
        mp_api_key = get_mp_api_key()
        ids_list_robo = self._pre_retrieve_robo(mute_progress_bars=mute_progress_bars)
        ids_list_perovproj = self._pre_retrieve_perovproj(mute_progress_bars=mute_progress_bars)
        logger.debug("MP querying for perovskites.")

        ids_list = ids_list_robo + ids_list_perovproj
        unique_ids = list()
        for x in ids_list:
            if x not in unique_ids:
                unique_ids.append(x)

        with MPRester(mp_api_key, mute_progress_bars=mute_progress_bars) as mpr:
            try:
                query = mpr.materials.summary.search(
                    material_ids=unique_ids, fields=MAT_PROPERTIES
                )
                logger.info(
                    f"MP query successful, {len(query)} perovskites found through Robocrystallographer and Perovskite Project formulae."
                )
            except Exception as e:
                raise e
        logger.debug("Converting MP query results into DataFrame.")
        perovskites_database = convert_my_query_to_dataframe(
            query, mute_progress_bars=mute_progress_bars
        )

        query_ids = list()
        for m in query:
            query_ids.append(m.material_id)

        # Fast cleaning
        logger.debug("Removing NaN (rows)")
        logger.debug(f"size DB before = {len(perovskites_database)}")
        perovskites_database = perovskites_database.dropna(
            axis=0, how="any", subset=BAND_CRITICAL_FIELD
        )
        logger.debug(f"size DB after = {len(perovskites_database)}")
        logger.debug("Removing NaN (cols)")
        logger.debug(f"size DB before = {len(perovskites_database)}")
        perovskites_database = perovskites_database.dropna(axis=1, how="all")
        logger.debug(f"size DB after = {len(perovskites_database)}")

        # Filtering
        logger.debug("Removing metallic perovskites.")
        logger.debug(f"size DB before = {len(perovskites_database)}")
        perovskites_database["is_metal"] = perovskites_database["is_metal"].astype(bool)
        filtered_perov_database = perovskites_database[~(perovskites_database["is_metal"])]
        # filtered_perov_database = perovskites_database
        logger.debug(f"size DB after = {len(filtered_perov_database)}")

        query_ids_filtered = filtered_perov_database["material_id"]
        diff = set(query_ids) - set(query_ids_filtered)

        reach_end = False
        while not reach_end:
            for i, q in enumerate(query):
                if q.material_id in diff:
                    query.pop(i)
                    break
            if i == len(query) - 1:
                reach_end = True

        filtered_perov_database.reset_index(drop=True, inplace=True)
        self._perovskites = filtered_perov_database.copy()

        logger.success("Perovskites retrieved successfully.")
        return self._perovskites, query

    def save_cif_files(
        self,
        stage: str,
        mute_progress_bars: bool = True,
    ) -> None:
        """
        Save CIF files for materials and update the database efficiently.

        This method retrieves crystal structures from the Materials Project API and saves them
        as CIF files in the appropriate directory. It ensures raw data integrity and efficiently
        updates the database with CIF file paths.

        Args:
            stage (str): The processing stage (`raw`, `processed`, `final`).
            mute_progress_bars (bool, optional): If True, disables progress bars. Defaults to True.

        Raises:
            ImmutableRawDataError: If attempting to modify immutable raw data.

        Logs:
            - WARNING: If the CIF directory is being cleaned or if a material ID is missing.
            - ERROR: If an API query fails or CIF file saving encounters an error.
            - INFO: When CIF files are successfully retrieved and saved.
        """
        saving_dir = self.database_directories[stage] / "structures/"
        database = self.get_database(stage)

        # Ensure raw data integrity
        if stage == "raw" and not self._update_raw:
            logger.error("Raw data must be treated as immutable!")
            raise ImmutableRawDataError("Raw data must be treated as immutable!")

        # Clear and recreate directory
        if saving_dir.exists():
            logger.warning(f"Cleaning {saving_dir}")
            sh.rmtree(saving_dir)
        saving_dir.mkdir(parents=True, exist_ok=False)

        # Create a lookup dictionary for material IDs → DataFrame row indices
        material_id_to_index = {mid: idx for idx, mid in enumerate(database["material_id"])}

        logger.debug("MP querying for perovskite structures.")
        mp_api_key = get_mp_api_key()

        with MPRester(mp_api_key, mute_progress_bars=mute_progress_bars) as mpr:
            try:
                materials_mp_query = mpr.materials.summary.search(
                    material_ids=list(material_id_to_index.keys()),
                    fields=["material_id", "structure"],
                )
                logger.info(f"MP query successful, {len(materials_mp_query)} structures found.")
            except Exception as e:
                logger.error(f"MP query failed: {e}")
                raise e

        all_cif_paths = {}  # Store updates in a dict to vectorize DataFrame updates later

        # Define a function to save CIF files in parallel
        def save_cif(material):
            try:
                material_id = material.material_id
                if material_id not in material_id_to_index:
                    logger.warning(f"Material ID {material_id} not found in database.")
                    return None

                cif_path = saving_dir / f"{material_id}.cif"
                material.structure.to(filename=str(cif_path))
                return material_id, str(cif_path)

            except Exception as e:
                logger.error(f"Failed to save CIF for {material.material_id}: {e}")
                return None

        # Parallelize CIF saving (adjust max_workers based on your system)
        with ThreadPoolExecutor(max_workers=None) as executor:
            results = list(executor.map(save_cif, materials_mp_query))

        # Collect results in dictionary for bulk DataFrame update
        for result in results:
            if result:
                material_id, cif_path = result
                all_cif_paths[material_id] = cif_path

        # Bulk update DataFrame in one step (vectorized)
        database["cif_path"] = database["material_id"].map(all_cif_paths)

        # Save the updated database
        self.save_database(stage)
        logger.info(f"CIF files for stage '{stage}' saved and database updated successfully.")

    def process_database(
        self,
        band_gap_lower: float,
        band_gap_upper: float,
        inplace: bool = True,
        db: pd.DataFrame = None,
        clean_magnetic: bool = True,
    ) -> pd.DataFrame:
        """
        Process the raw perovskite database to the `processed` stage.

        This method filters materials based on their band gap energy range and optionally
        removes metallic and magnetic materials. The processed data can either be saved
        in place or returned as a separate DataFrame.

        Args:
            band_gap_lower (float): Lower bound of the band gap energy range (in eV).
            band_gap_upper (float): Upper bound of the band gap energy range (in eV).
            inplace (bool, optional): If True, updates the "processed" database in-place.
                If False, returns the processed DataFrame. Defaults to True.
            db (pd.DataFrame, optional): The database to process if `inplace` is False. Defaults to None.
            clean_magnetic (bool, optional): If True, removes magnetic materials. Defaults to True.

        Returns:
            (pd.DataFrame): Processed database if `inplace` is False, otherwise None.

        Raises:
            ValueError: If `inplace` is False but no DataFrame (`db`) is provided.

        Logs:
            - INFO: Steps of the processing, including filtering metallic and magnetic materials.
            - ERROR: If `inplace` is False and no input database is provided.
        """

        if not inplace and db is None:
            logger.error(
                "Invalid input: a pd.DataFrame must be passed via 'db' when 'inplace' is set to False."
            )
            raise ValueError("A pd.DataFrame must be passed via 'db' when 'inplace' is set to False.")

        if inplace:
            raw_db = self.get_database(stage="raw")
        else:
            raw_db = db

        raw_db = raw_db[~(raw_db["is_metal"].astype(bool))]
        logger.info("Removing metallic materials")
        if clean_magnetic:
            temp_db = raw_db[~(raw_db["is_magnetic"].astype(bool))]
            logger.info("Removing magnetic materials")
        else:
            temp_db = raw_db
            logger.info("Keeping magnetic materials")

        logger.info(
            f"Keeping materials with band gap in range {band_gap_lower} eV < E_g <= {band_gap_upper} eV"
        )
        processed_db = temp_db[
            (temp_db["band_gap"] > band_gap_lower) & (temp_db["band_gap"] <= band_gap_upper)
        ]

        processed_db.reset_index(drop=True, inplace=True)

        if inplace:
            self.databases["processed"] = processed_db.copy()
            self.save_database("processed")
        else:
            return processed_db

    def __repr__(self) -> str:
        """
        Text representation of the PerovskiteDatabase instance.
        Used for `print()` and `str()` calls.

        Returns:
            str: ASCII table representation of the database
        """
        # Gather information about each stage
        data = {
            "Stage": [],
            "Entries": [],
            "Last Modified": [],
            "Size": [],
            "Storage Path": [],
        }

        # Calculate column widths
        widths = [10, 8, 17, 10, 55]

        for stage in self.processing_stages:
            # Get database info
            db = self.databases[stage]
            path = self.database_paths[stage]

            # Get file modification time and size if file exists
            if path.exists():
                modified = path.stat().st_mtime
                modified_time = pd.Timestamp.fromtimestamp(modified).strftime("%Y-%m-%d %H:%M")
                size = path.stat().st_size / 1024  # Convert to KB
                size_str = f"{size:.1f} KB" if size < 1024 else f"{size / 1024:.1f} MB"
            else:
                modified_time = "Not created"
                size_str = "0 KB"

            path_str = str(path.resolve())
            if len(path_str) > widths[4]:
                path_str = ".." + path_str[len(path_str) - widths[4] + 3 :]

            # Append data
            data["Stage"].append(stage.capitalize())
            data["Entries"].append(len(db))
            data["Last Modified"].append(modified_time)
            data["Size"].append(size_str)
            data["Storage Path"].append(path_str)

        # Create DataFrame
        info_df = pd.DataFrame(data)

        # Text representation for terminal/print
        def create_separator(widths):
            return "+" + "+".join("-" * (w + 1) for w in widths) + "+"

        # Create the text representation
        lines = []

        # Add title
        title = f" {self.__class__.__name__} Summary "
        lines.append(f"\n{title:=^{sum(widths) + len(widths) * 2 + 1}}")

        # Add header
        separator = create_separator(widths)
        lines.append(separator)

        header = (
            "|" + "|".join(f" {col:<{widths[i]}}" for i, col in enumerate(info_df.columns)) + "|"
        )
        lines.append(header)
        lines.append(separator)

        # Add data rows
        for _, row in info_df.iterrows():
            line = "|" + "|".join(f" {str(val):<{widths[i]}}" for i, val in enumerate(row)) + "|"
            lines.append(line)

        # Add bottom separator
        lines.append(separator)

        return "\n".join(lines)

    def _repr_html_(self) -> str:
        """
        HTML representation of the PerovskiteDatabase instance.
        Used for Jupyter notebook display.

        Returns:
            str: HTML representation of the database
        """
        # Gather information about each stage
        data = {
            "Stage": [],
            "Entries": [],
            "Last Modified": [],
            "Size": [],
            "Storage Path": [],
        }

        for stage in self.processing_stages:
            # Get database info
            db = self.databases[stage]
            path = self.database_paths[stage]

            # Get file modification time and size if file exists
            if path.exists():
                modified = path.stat().st_mtime
                modified_time = pd.Timestamp.fromtimestamp(modified).strftime("%Y-%m-%d %H:%M")
                size = path.stat().st_size / 1024  # Convert to KB
                size_str = f"{size:.1f} KB" if size < 1024 else f"{size / 1024:.1f} MB"
            else:
                modified_time = "Not created"
                size_str = "0 KB"

            # Append data
            data["Stage"].append(stage.capitalize())
            data["Entries"].append(len(db))
            data["Last Modified"].append(modified_time)
            data["Size"].append(size_str)
            data["Storage Path"].append(str(path.resolve()))

        # Create DataFrame
        info_df = pd.DataFrame(data)

        # Generate header row
        header_cells = " ".join(
            f'<th style="padding: 12px 15px; text-align: left;">{col}</th>'
            for col in info_df.columns
        )

        # Generate table rows
        table_rows = ""
        for _, row in info_df.iterrows():
            cells = "".join(f'<td style="padding: 12px 15px;">{val}</td>' for val in row)
            table_rows += f"<tr style='border-bottom: 1px solid #e9ecef;'>{cells}</tr>"

        # Create the complete HTML
        html = (
            """<style>
                @media (prefers-color-scheme: dark) {
                    .database-container { background-color: #1e1e1e !important; }
                    .database-title { color: #e0e0e0 !important; }
                    .database-table { background-color: #2d2d2d !important; }
                    .database-header { background-color: #4a4a4a !important; }
                    .database-cell { border-color: #404040 !important; }
                    .database-info { color: #b0b0b0 !important; }
                }
            </style>"""
            '<div style="font-family: Arial, sans-serif; padding: 20px; background:transparent; '
            'border-radius: 8px;">'
            f'<h3 style="color: #58bac7; margin-bottom: 15px;">{self.__class__.__name__}</h3>'
            '<div style="overflow-x: auto;">'
            '<table class="database-table" style="border-collapse: collapse; width: 100%;'
            ' box-shadow: 0 1px 3px rgba(0,0,0,0.1); background:transparent;">'
            # '<table style="border-collapse: collapse; width: 100%; background-color: white; '
            # 'box-shadow: 0 1px 3px rgba(0,0,0,0.1);">'
            "<thead>"
            f'<tr style="background-color: #58bac7; color: white;">{header_cells}</tr>'
            "</thead>"
            f"<tbody>{table_rows}</tbody>"
            "</table>"
            "</div>"
            '<div style="margin-top: 10px; color: #666; font-size: 1.1em;">'
            "</div>"
            "</div>"
        )
        return html

__init__(name='perovskites', data_dir=DATA_DIR, external_perovproj_path=EXTERNAL_DATA_DIR / Path('perovskites') / Path('perovproject_db.json'))

This constructor sets up the directory structure for storing data across different processing stages (raw, processed, final). It also initializes placeholders for the database paths and data specific to the PerovskiteDatabase class, including a path for the external Perovskite project data.

Parameters:

Name Type Description Default
name str

The name of the database. Defaults to "perovskites".

'perovskites'
data_dir Path or str

Root directory path for storing data. Defaults to DATA_DIR from the configuration.

DATA_DIR
external_perovproj_path Path or str

Path to the external Perovskite Project database file. Defaults to EXTERNAL_DATA_DIR / "perovskites" / "perovproject_db.json".

EXTERNAL_DATA_DIR / Path('perovskites') / Path('perovproject_db.json')

Raises:

Type Description
NotImplementedError

If the specified processing stage is not supported.

ImmutableRawDataError

If attempting to set an unsupported processing stage.

Source code in energy_gnome/dataset/perovskites.py
Python
def __init__(
    self,
    name: str = "perovskites",
    data_dir: Path | str = DATA_DIR,
    external_perovproj_path: Path | str = EXTERNAL_DATA_DIR
    / Path("perovskites")
    / Path("perovproject_db.json"),
):
    """
    Initialize the PerovskiteDatabase with a root data directory and processing stage.

    This constructor sets up the directory structure for storing data across different
    processing stages (`raw`, `processed`, `final`). It also initializes placeholders
    for the database paths and data specific to the `PerovskiteDatabase` class, including
    a path for the external Perovskite project data.

    Args:
        name (str, optional): The name of the database. Defaults to "perovskites".
        data_dir (Path or str, optional): Root directory path for storing data. Defaults
            to `DATA_DIR` from the configuration.
        external_perovproj_path (Path or str, optional): Path to the external Perovskite
            Project database file. Defaults to `EXTERNAL_DATA_DIR / "perovskites" /
            "perovproject_db.json"`.

    Raises:
        NotImplementedError: If the specified processing stage is not supported.
        ImmutableRawDataError: If attempting to set an unsupported processing stage.
    """
    super().__init__(name=name, data_dir=data_dir)
    self._perovskites = pd.DataFrame()
    self.external_perovproj_path: Path | str = external_perovproj_path
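
Usage sketch (illustrative, not part of the package source): the snippet below assumes `PerovskiteDatabase` is exported from `energy_gnome.dataset` and that the default `DATA_DIR`/`EXTERNAL_DATA_DIR` configuration is available; the custom paths are placeholders.

Python
from pathlib import Path

from energy_gnome.dataset import PerovskiteDatabase

# Use the package defaults for data_dir and the external Perovskite Project file.
perov_db = PerovskiteDatabase()

# Or point the database at a project-specific directory and external file (hypothetical paths).
perov_db_custom = PerovskiteDatabase(
    name="perovskites_demo",
    data_dir=Path("./data"),
    external_perovproj_path=Path("./data/external/perovproject_db.json"),
)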

__repr__()

Text representation of the PerovskiteDatabase instance. Used for print() and str() calls.

Returns:

Name Type Description
str str

ASCII table representation of the database

Source code in energy_gnome/dataset/perovskites.py
Python
def __repr__(self) -> str:
    """
    Text representation of the PerovskiteDatabase instance.
    Used for `print()` and `str()` calls.

    Returns:
        str: ASCII table representation of the database
    """
    # Gather information about each stage
    data = {
        "Stage": [],
        "Entries": [],
        "Last Modified": [],
        "Size": [],
        "Storage Path": [],
    }

    # Calculate column widths
    widths = [10, 8, 17, 10, 55]

    for stage in self.processing_stages:
        # Get database info
        db = self.databases[stage]
        path = self.database_paths[stage]

        # Get file modification time and size if file exists
        if path.exists():
            modified = path.stat().st_mtime
            modified_time = pd.Timestamp.fromtimestamp(modified).strftime("%Y-%m-%d %H:%M")
            size = path.stat().st_size / 1024  # Convert to KB
            size_str = f"{size:.1f} KB" if size < 1024 else f"{size / 1024:.1f} MB"
        else:
            modified_time = "Not created"
            size_str = "0 KB"

        path_str = str(path.resolve())
        if len(path_str) > widths[4]:
            path_str = ".." + path_str[len(path_str) - widths[4] + 3 :]

        # Append data
        data["Stage"].append(stage.capitalize())
        data["Entries"].append(len(db))
        data["Last Modified"].append(modified_time)
        data["Size"].append(size_str)
        data["Storage Path"].append(path_str)

    # Create DataFrame
    info_df = pd.DataFrame(data)

    # Text representation for terminal/print
    def create_separator(widths):
        return "+" + "+".join("-" * (w + 1) for w in widths) + "+"

    # Create the text representation
    lines = []

    # Add title
    title = f" {self.__class__.__name__} Summary "
    lines.append(f"\n{title:=^{sum(widths) + len(widths) * 2 + 1}}")

    # Add header
    separator = create_separator(widths)
    lines.append(separator)

    header = (
        "|" + "|".join(f" {col:<{widths[i]}}" for i, col in enumerate(info_df.columns)) + "|"
    )
    lines.append(header)
    lines.append(separator)

    # Add data rows
    for _, row in info_df.iterrows():
        line = "|" + "|".join(f" {str(val):<{widths[i]}}" for i, val in enumerate(row)) + "|"
        lines.append(line)

    # Add bottom separator
    lines.append(separator)

    return "\n".join(lines)
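
A quick, hedged example of how this representation is typically consumed (construction as in the sketch above):

Python
from energy_gnome.dataset import PerovskiteDatabase

perov_db = PerovskiteDatabase()
print(perov_db)  # print() and str() trigger __repr__ and render the ASCII summary table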

process_database(band_gap_lower, band_gap_upper, inplace=True, db=None, clean_magnetic=True)

Process the raw perovskite database to the processed stage.

This method filters materials based on their band gap energy range and optionally removes metallic and magnetic materials. The processed data can either be saved in place or returned as a separate DataFrame.

Parameters:

Name Type Description Default
band_gap_lower float

Lower bound of the band gap energy range (in eV).

required
band_gap_upper float

Upper bound of the band gap energy range (in eV).

required
inplace bool

If True, updates the "processed" database in-place. If False, returns the processed DataFrame. Defaults to True.

True
db DataFrame

The database to process if inplace is False. Defaults to None.

None
clean_magnetic bool

If True, removes magnetic materials. Defaults to True.

True

Returns:

Type Description
DataFrame

Processed database if inplace is False, otherwise None.

Raises:

Type Description
ValueError

If inplace is False but no DataFrame (db) is provided.

Logs
  • INFO: Steps of the processing, including filtering metallic and magnetic materials.
  • ERROR: If inplace is False and no input database is provided.
Source code in energy_gnome/dataset/perovskites.py
Python
def process_database(
    self,
    band_gap_lower: float,
    band_gap_upper: float,
    inplace: bool = True,
    db: pd.DataFrame = None,
    clean_magnetic: bool = True,
) -> pd.DataFrame:
    """
    Process the raw perovskite database to the `processed` stage.

    This method filters materials based on their band gap energy range and optionally
    removes metallic and magnetic materials. The processed data can either be saved
    in place or returned as a separate DataFrame.

    Args:
        band_gap_lower (float): Lower bound of the band gap energy range (in eV).
        band_gap_upper (float): Upper bound of the band gap energy range (in eV).
        inplace (bool, optional): If True, updates the "processed" database in-place.
            If False, returns the processed DataFrame. Defaults to True.
        db (pd.DataFrame, optional): The database to process if `inplace` is False. Defaults to None.
        clean_magnetic (bool, optional): If True, removes magnetic materials. Defaults to True.

    Returns:
        (pd.DataFrame): Processed database if `inplace` is False, otherwise None.

    Raises:
        ValueError: If `inplace` is False but no DataFrame (`db`) is provided.

    Logs:
        - INFO: Steps of the processing, including filtering metallic and magnetic materials.
        - ERROR: If `inplace` is False and no input database is provided.
    """

    if not inplace and db is None:
        logger.error(
            "Invalid input: a pd.DataFrame must be passed via 'db' when 'inplace' is set to False."
        )
        raise ValueError("A pd.DataFrame must be passed via 'db' when 'inplace' is set to False.")

    if inplace:
        raw_db = self.get_database(stage="raw")
    else:
        raw_db = db

    raw_db = raw_db[~(raw_db["is_metal"].astype(bool))]
    logger.info("Removing metallic materials")
    if clean_magnetic:
        temp_db = raw_db[~(raw_db["is_magnetic"].astype(bool))]
        logger.info("Removing magnetic materials")
    else:
        temp_db = raw_db
        logger.info("Keeping magnetic materials")

    logger.info(
        f"Keeping materials with band gap in range {band_gap_lower} eV < E_g <= {band_gap_upper} eV"
    )
    processed_db = temp_db[
        (temp_db["band_gap"] > band_gap_lower) & (temp_db["band_gap"] <= band_gap_upper)
    ]

    processed_db.reset_index(drop=True, inplace=True)

    if inplace:
        self.databases["processed"] = processed_db.copy()
        self.save_database("processed")
    else:
        return processed_db
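
A minimal usage sketch, assuming `perov_db` is a PerovskiteDatabase whose raw stage has already been populated and saved; the band-gap bounds are illustrative:

Python
# In-place: filters the raw stage and saves the result as the "processed" stage.
perov_db.process_database(
    band_gap_lower=0.0,   # keep entries with E_g > 0.0 eV ...
    band_gap_upper=2.5,   # ... and E_g <= 2.5 eV
    clean_magnetic=True,  # also drop magnetic materials
)

# Out-of-place: process an arbitrary DataFrame without touching the stored stages.
raw_copy = perov_db.get_database("raw").copy()
processed = perov_db.process_database(
    band_gap_lower=0.0,
    band_gap_upper=2.5,
    inplace=False,
    db=raw_copy,
)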

retrieve_materials(mute_progress_bars=True)

Retrieve Perovskite materials from the Materials Project API.

This method connects to the Materials Project API using the MPRester, queries for materials related to Perovskites, and retrieves specified properties. The data is then cleaned by removing rows with missing critical fields and filtering out metallic perovskites. The method returns the cleaned DataFrame of Perovskites together with the raw Materials Project query results.

Parameters:

Name Type Description Default
mute_progress_bars bool

If True, mutes the Material Project API progress bars during the request. Defaults to True.

True

Returns:

Type Description
tuple[DataFrame, list]

The cleaned DataFrame of Perovskite materials and the corresponding list of raw Materials Project query results.

Raises:

Type Description
Exception

If the API query fails or any issue occurs during data retrieval.

Logs
  • DEBUG: Logs the process of querying and cleaning data.
  • INFO: Logs successful query results and how many Perovskites were retrieved.
  • SUCCESS: Logs the successful retrieval of Perovskite materials.
Source code in energy_gnome/dataset/perovskites.py
Python
def retrieve_materials(self, mute_progress_bars: bool = True) -> tuple[pd.DataFrame, list]:
    """
    Retrieve Perovskite materials from the Materials Project API.

    This method connects to the Materials Project API using the `MPRester`, queries
    for materials related to Perovskites, and retrieves specified properties. The data
    is then cleaned by removing rows with missing critical fields and filtering out
    metallic perovskites. The method returns the cleaned DataFrame of Perovskites
    together with the raw Materials Project query results.

    Args:
        mute_progress_bars (bool, optional):
            If `True`, mutes the Material Project API progress bars during the request.
            Defaults to `True`.

    Returns:
        (tuple[pd.DataFrame, list]): The cleaned DataFrame of Perovskite materials and
        the corresponding list of raw Materials Project query results.

    Raises:
        Exception: If the API query fails or any issue occurs during data retrieval.

    Logs:
        - DEBUG: Logs the process of querying and cleaning data.
        - INFO: Logs successful query results and how many Perovskites were retrieved.
        - SUCCESS: Logs the successful retrieval of Perovskite materials.
    """
    mp_api_key = get_mp_api_key()
    ids_list_robo = self._pre_retrieve_robo(mute_progress_bars=mute_progress_bars)
    ids_list_perovproj = self._pre_retrieve_perovproj(mute_progress_bars=mute_progress_bars)
    logger.debug("MP querying for perovskites.")

    ids_list = ids_list_robo + ids_list_perovproj
    unique_ids = list()
    for x in ids_list:
        if x not in unique_ids:
            unique_ids.append(x)

    with MPRester(mp_api_key, mute_progress_bars=mute_progress_bars) as mpr:
        try:
            query = mpr.materials.summary.search(
                material_ids=unique_ids, fields=MAT_PROPERTIES
            )
            logger.info(
                f"MP query successful, {len(query)} perovskites found through Robocrystallographer and Perovskite Project formulae."
            )
        except Exception as e:
            raise e
    logger.debug("Converting MP query results into DataFrame.")
    perovskites_database = convert_my_query_to_dataframe(
        query, mute_progress_bars=mute_progress_bars
    )

    query_ids = list()
    for m in query:
        query_ids.append(m.material_id)

    # Fast cleaning
    logger.debug("Removing NaN (rows)")
    logger.debug(f"size DB before = {len(perovskites_database)}")
    perovskites_database = perovskites_database.dropna(
        axis=0, how="any", subset=BAND_CRITICAL_FIELD
    )
    logger.debug(f"size DB after = {len(perovskites_database)}")
    logger.debug("Removing NaN (cols)")
    logger.debug(f"size DB before = {len(perovskites_database)}")
    perovskites_database = perovskites_database.dropna(axis=1, how="all")
    logger.debug(f"size DB after = {len(perovskites_database)}")

    # Filtering
    logger.debug("Removing metallic perovskites.")
    logger.debug(f"size DB before = {len(perovskites_database)}")
    perovskites_database["is_metal"] = perovskites_database["is_metal"].astype(bool)
    filtered_perov_database = perovskites_database[~(perovskites_database["is_metal"])]
    # filtered_perov_database = perovskites_database
    logger.debug(f"size DB after = {len(filtered_perov_database)}")

    query_ids_filtered = filtered_perov_database["material_id"]
    diff = set(query_ids) - set(query_ids_filtered)

    reach_end = False
    while not reach_end:
        for i, q in enumerate(query):
            if q.material_id in diff:
                query.pop(i)
                break
        if i == len(query) - 1:
            reach_end = True

    filtered_perov_database.reset_index(drop=True, inplace=True)
    self._perovskites = filtered_perov_database.copy()

    logger.success("Perovskites retrieved successfully.")
    return self._perovskites, query
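
A hedged usage sketch: this assumes a valid Materials Project API key is configured so that `get_mp_api_key()` can resolve it, and that `perov_db` is a PerovskiteDatabase instance as constructed above:

Python
# Returns the cleaned DataFrame and the raw Materials Project query documents.
perov_df, mp_docs = perov_db.retrieve_materials(mute_progress_bars=False)

print(f"{len(perov_df)} non-metallic perovskite candidates retrieved")
print(perov_df[["material_id", "band_gap"]].head())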

save_cif_files(stage, mute_progress_bars=True)

Save CIF files for materials and update the database efficiently.

This method retrieves crystal structures from the Materials Project API and saves them as CIF files in the appropriate directory. It ensures raw data integrity and efficiently updates the database with CIF file paths.

Parameters:

Name Type Description Default
stage str

The processing stage (raw, processed, final).

required
mute_progress_bars bool

If True, disables progress bars. Defaults to True.

True

Raises:

Type Description
ImmutableRawDataError

If attempting to modify immutable raw data.

Logs
  • WARNING: If the CIF directory is being cleaned or if a material ID is missing.
  • ERROR: If an API query fails or CIF file saving encounters an error.
  • INFO: When CIF files are successfully retrieved and saved.
Source code in energy_gnome/dataset/perovskites.py
Python
def save_cif_files(
    self,
    stage: str,
    mute_progress_bars: bool = True,
) -> None:
    """
    Save CIF files for materials and update the database efficiently.

    This method retrieves crystal structures from the Materials Project API and saves them
    as CIF files in the appropriate directory. It ensures raw data integrity and efficiently
    updates the database with CIF file paths.

    Args:
        stage (str): The processing stage (`raw`, `processed`, `final`).
        mute_progress_bars (bool, optional): If True, disables progress bars. Defaults to True.

    Raises:
        ImmutableRawDataError: If attempting to modify immutable raw data.

    Logs:
        - WARNING: If the CIF directory is being cleaned or if a material ID is missing.
        - ERROR: If an API query fails or CIF file saving encounters an error.
        - INFO: When CIF files are successfully retrieved and saved.
    """
    saving_dir = self.database_directories[stage] / "structures/"
    database = self.get_database(stage)

    # Ensure raw data integrity
    if stage == "raw" and not self._update_raw:
        logger.error("Raw data must be treated as immutable!")
        raise ImmutableRawDataError("Raw data must be treated as immutable!")

    # Clear and recreate directory
    if saving_dir.exists():
        logger.warning(f"Cleaning {saving_dir}")
        sh.rmtree(saving_dir)
    saving_dir.mkdir(parents=True, exist_ok=False)

    # Create a lookup dictionary for material IDs → DataFrame row indices
    material_id_to_index = {mid: idx for idx, mid in enumerate(database["material_id"])}

    logger.debug("MP querying for perovskite structures.")
    mp_api_key = get_mp_api_key()

    with MPRester(mp_api_key, mute_progress_bars=mute_progress_bars) as mpr:
        try:
            materials_mp_query = mpr.materials.summary.search(
                material_ids=list(material_id_to_index.keys()),
                fields=["material_id", "structure"],
            )
            logger.info(f"MP query successful, {len(materials_mp_query)} structures found.")
        except Exception as e:
            logger.error(f"MP query failed: {e}")
            raise e

    all_cif_paths = {}  # Store updates in a dict to vectorize DataFrame updates later

    # Define a function to save CIF files in parallel
    def save_cif(material):
        try:
            material_id = material.material_id
            if material_id not in material_id_to_index:
                logger.warning(f"Material ID {material_id} not found in database.")
                return None

            cif_path = saving_dir / f"{material_id}.cif"
            material.structure.to(filename=str(cif_path))
            return material_id, str(cif_path)

        except Exception as e:
            logger.error(f"Failed to save CIF for {material.material_id}: {e}")
            return None

    # Parallelize CIF saving (adjust max_workers based on your system)
    with ThreadPoolExecutor(max_workers=None) as executor:
        results = list(executor.map(save_cif, materials_mp_query))

    # Collect results in dictionary for bulk DataFrame update
    for result in results:
        if result:
            material_id, cif_path = result
            all_cif_paths[material_id] = cif_path

    # Bulk update DataFrame in one step (vectorized)
    database["cif_path"] = database["material_id"].map(all_cif_paths)

    # Save the updated database
    self.save_database(stage)
    logger.info(f"CIF files for stage '{stage}' saved and database updated successfully.")
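
A minimal sketch, assuming `perov_db` already holds a non-raw stage (the raw stage is immutable unless the internal `_update_raw` flag is enabled):

Python
# Download structures and write one CIF file per material under <stage>/structures/.
perov_db.save_cif_files(stage="processed", mute_progress_bars=False)

# The stage DataFrame now carries the path of each CIF file.
processed = perov_db.get_database("processed")
print(processed[["material_id", "cif_path"]].head())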

energy_gnome.dataset.MPDatabase

Bases: BaseDatabase

Source code in energy_gnome/dataset/random_mats.py
Python
class MPDatabase(BaseDatabase):
    def __init__(self, name: str = "mp", data_dir: Path | str = DATA_DIR):
        """
        Initialize the MPDatabase with a root data directory and processing stage.

        This constructor sets up the directory structure for storing data across different
        processing stages (`raw`, `processed`, `final`). It also initializes placeholders
        for the database paths and data specific to the `MPDatabase` class.

        Args:
            name (str, optional): The name of the database. Defaults to "mp".
            data_dir (Path or str, optional): Root directory path for storing data. Defaults
                to `DATA_DIR` from the configuration.

        Raises:
            NotImplementedError: If the specified processing stage is not supported.
            ImmutableRawDataError: If attempting to set an unsupported processing stage.
        """
        super().__init__(name=name, data_dir=data_dir)

        # Force single directory for raw database of MPDatabase
        self.database_directories["raw"] = self.data_dir / "raw" / "mp"

        self._mp = pd.DataFrame()

    def _set_is_specialized(self):
        """
        Set the `is_specialized` attribute to `False`.

        This method marks the database as non-specialized by setting the `is_specialized`
        attribute to `False`. It indicates that the database holds generic materials
        rather than a specific class of specialized energy materials.

        Returns:
            None
        """
        self.is_specialized = False

    def retrieve_materials(
        self, max_framework_size: int = 6, mute_progress_bars: bool = True
    ) -> tuple[pd.DataFrame, list]:
        """
        Retrieve materials from the Materials Project API.

        This method connects to the Materials Project API using the `MPRester`, queries
        for all materials, and retrieves specified properties. The data is then cleaned
        by removing rows with missing critical fields.
        The method returns the cleaned DataFrame of materials together with the raw
        Materials Project query results.

        Args:
            max_framework_size (int, optional):
                Controls the number of chemical-system wildcard queries issued to the
                Materials Project (one query per system size); larger values retrieve
                materials with more elements. Defaults to `6`.
            mute_progress_bars (bool, optional):
                If `True`, mutes the Material Project API progress bars during the request.
                Defaults to `True`.

        Returns:
            (tuple[pd.DataFrame, list]): The cleaned DataFrame of materials and the
            corresponding list of raw Materials Project query results.

        Raises:
            Exception: If the API query fails or any issue occurs during data retrieval.

        Logs:
            - DEBUG: Logs the process of querying and cleaning data.
            - INFO: Logs successful query results and how many materials were retrieved.
            - SUCCESS: Logs the successful retrieval of materials.
        """
        mp_api_key = get_mp_api_key()
        logger.debug("MP querying for all materials.")
        query = []

        with MPRester(mp_api_key, mute_progress_bars=mute_progress_bars) as mpr:
            try:
                for n_elm in range(1, max_framework_size + 1):
                    chemsys = "*" + "-*" * n_elm
                    logger.info(f"Retrieving all materials with chemical system = {chemsys} :")
                    query += mpr.materials.summary.search(chemsys=chemsys, fields=MAT_PROPERTIES)
                logger.info(f"MP query successful, {len(query)} materials found.")
            except Exception as e:
                raise e
        logger.debug("Converting MP query results into DataFrame.")
        mp_database = convert_my_query_to_dataframe(query, mute_progress_bars=mute_progress_bars)

        # Fast cleaning
        logger.debug("Removing NaN (rows)")
        logger.debug(f"size DB before = {len(mp_database)}")
        mp_database = mp_database.dropna(axis=0, how="any", subset=CRITICAL_FIELD)
        logger.debug(f"size DB after = {len(mp_database)}")
        logger.debug("Removing NaN (cols)")
        logger.debug(f"size DB before = {len(mp_database)}")
        mp_database = mp_database.dropna(axis=1, how="all")
        logger.debug(f"size DB after = {len(mp_database)}")

        mp_database.reset_index(drop=True, inplace=True)
        self._mp = mp_database.copy()

        logger.success("Materials retrieved successfully.")
        return self._mp, query

    def save_cif_files(
        self,
        stage: str,
        mute_progress_bars: bool = True,
    ) -> None:
        """
        Save CIF files for materials and update the database efficiently.

        This method retrieves crystal structures from the Materials Project API and saves them
        as CIF files in the appropriate directory. It ensures raw data integrity and efficiently
        updates the database with CIF file paths.

        Args:
            stage (str): The processing stage (`raw`, `processed`, `final`).
            mute_progress_bars (bool, optional): If True, disables progress bars. Defaults to True.

        Raises:
            ImmutableRawDataError: If attempting to modify immutable raw data.

        Logs:
            - WARNING: If the CIF directory is being cleaned or if a material ID is missing.
            - ERROR: If an API query fails or CIF file saving encounters an error.
            - INFO: When CIF files are successfully retrieved and saved.
        """

        # Set up directory for saving CIF files
        saving_dir = self.database_directories[stage] / "structures/"
        database = self.get_database(stage)

        # Ensure raw data integrity
        if stage == "raw" and not self._update_raw:
            logger.error("Raw data must be treated as immutable!")
            raise ImmutableRawDataError("Raw data must be treated as immutable!")

        # Clear directory if it exists
        if saving_dir.exists():
            logger.warning(f"Cleaning {saving_dir}")
            sh.rmtree(saving_dir)
        saving_dir.mkdir(parents=True, exist_ok=False)

        # Create a lookup dictionary for material IDs → DataFrame row indices
        material_id_to_index = {mid: idx for idx, mid in enumerate(database["material_id"])}

        # Fetch structures in batches
        ids_list = database["material_id"].tolist()
        n_batch = int(np.ceil(len(ids_list) / MP_BATCH_SIZE))
        mp_api_key = get_mp_api_key()

        logger.debug("MP querying for materials' structures.")
        with MPRester(mp_api_key, mute_progress_bars=mute_progress_bars) as mpr:
            all_cif_paths = {}  # Store updates in a dict to vectorize DataFrame updates later

            for i_batch in tqdm(
                range(n_batch), desc="Saving materials", disable=mute_progress_bars
            ):
                i_star = i_batch * MP_BATCH_SIZE
                i_end = min((i_batch + 1) * MP_BATCH_SIZE, len(ids_list))

                try:
                    materials_mp_query = mpr.materials.summary.search(
                        material_ids=ids_list[i_star:i_end], fields=["material_id", "structure"]
                    )
                    logger.info(
                        f"MP query successful, {len(materials_mp_query)} structures found."
                    )
                except Exception as e:
                    logger.error(f"Failed MP query: {e}")
                    raise e

                # Define a function to save CIF files in parallel
                def save_cif(material):
                    try:
                        material_id = material.material_id
                        if material_id not in material_id_to_index:
                            logger.warning(f"Material ID {material_id} not found in database.")
                            return None

                        cif_path = saving_dir / f"{material_id}.cif"
                        material.structure.to(filename=str(cif_path))
                        return material_id, str(cif_path)

                    except Exception as e:
                        logger.error(f"Failed to save CIF for {material.material_id}: {e}")
                        return None

                # Parallelize CIF saving (adjust max_workers based on your system)
                with ThreadPoolExecutor(max_workers=None) as executor:
                    results = list(executor.map(save_cif, materials_mp_query))

                # Collect results in dictionary for bulk DataFrame update
                for result in results:
                    if result:
                        material_id, cif_path = result
                        all_cif_paths[material_id] = cif_path

        # Bulk update DataFrame in one step (vectorized)
        database["cif_path"] = database["material_id"].map(all_cif_paths)

        # Save the updated database
        self.save_database(stage)
        logger.info(f"CIF files for stage '{stage}' saved and database updated successfully.")

    def remove_cross_overlap(
        self,
        stage: str,
        database: pd.DataFrame,
    ) -> pd.DataFrame:
        """
        Remove entries in the generic database that overlap with a category-specific database.

        This function identifies and removes material entries that exist in both the generic
        and category-specific databases for a given processing stage.

        Args:
            stage (str): Processing stage (`raw`, `processed`, `final`).
            database (pd.DataFrame): The category-specific database to compare with the generic database.

        Returns:
            pd.DataFrame: The filtered generic database with overlapping entries removed.

        Raises:
            ValueError: If an invalid processing stage is provided.

        Logs:
            - ERROR: If an invalid stage is given.
            - INFO: Number of overlapping entries identified and removed.
        """
        if stage not in self.processing_stages:
            logger.error(f"Invalid stage: {stage}. Must be one of {self.processing_stages}.")
            raise ValueError(f"stage must be one of {self.processing_stages}.")

        self.load_database(stage)
        mp_database = self.get_database(stage)
        id_overlap = database["material_id"].tolist()

        mp_database["to_drop"] = mp_database.apply(
            lambda x: x["material_id"] in id_overlap, axis=1
        )
        to_drop = mp_database["to_drop"].value_counts().get(True, 0)
        logger.info(f"{to_drop} overlapping items to drop.")
        mp_database_no_overlap = mp_database.iloc[
            np.where(mp_database["to_drop"] == 0)[0], :
        ].reset_index(drop=True)
        mp_database_no_overlap.drop(columns=["to_drop"], inplace=True)

        return mp_database_no_overlap

    def __repr__(self) -> str:
        """
        Text representation of the MPDatabase instance.
        Used for `print()` and `str()` calls.

        Returns:
            str: ASCII table representation of the database
        """
        # Gather information about each stage
        data = {
            "Stage": [],
            "Entries": [],
            "Last Modified": [],
            "Size": [],
            "Storage Path": [],
        }

        # Calculate column widths
        widths = [10, 8, 17, 10, 55]

        for stage in self.processing_stages:
            # Get database info
            db = self.databases[stage]
            path = self.database_paths[stage]

            # Get file modification time and size if file exists
            if path.exists():
                modified = path.stat().st_mtime
                modified_time = pd.Timestamp.fromtimestamp(modified).strftime("%Y-%m-%d %H:%M")
                size = path.stat().st_size / 1024  # Convert to KB
                size_str = f"{size:.1f} KB" if size < 1024 else f"{size / 1024:.1f} MB"
            else:
                modified_time = "Not created"
                size_str = "0 KB"

            path_str = str(path.resolve())
            if len(path_str) > widths[4]:
                path_str = ".." + path_str[len(path_str) - widths[4] + 3 :]

            # Append data
            data["Stage"].append(stage.capitalize())
            data["Entries"].append(len(db))
            data["Last Modified"].append(modified_time)
            data["Size"].append(size_str)
            data["Storage Path"].append(path_str)

        # Create DataFrame
        info_df = pd.DataFrame(data)

        # Text representation for terminal/print
        def create_separator(widths):
            return "+" + "+".join("-" * (w + 1) for w in widths) + "+"

        # Create the text representation
        lines = []

        # Add title
        title = f" {self.__class__.__name__} Summary "
        lines.append(f"\n{title:=^{sum(widths) + len(widths) * 2 + 1}}")

        # Add header
        separator = create_separator(widths)
        lines.append(separator)

        header = (
            "|" + "|".join(f" {col:<{widths[i]}}" for i, col in enumerate(info_df.columns)) + "|"
        )
        lines.append(header)
        lines.append(separator)

        # Add data rows
        for _, row in info_df.iterrows():
            line = "|" + "|".join(f" {str(val):<{widths[i]}}" for i, val in enumerate(row)) + "|"
            lines.append(line)

        # Add bottom separator
        lines.append(separator)

        return "\n".join(lines)

    def _repr_html_(self) -> str:
        """
        HTML representation of the MPDatabase instance.
        Used for Jupyter notebook display.

        Returns:
            str: HTML representation of the database
        """
        # Gather information about each stage
        data = {
            "Stage": [],
            "Entries": [],
            "Last Modified": [],
            "Size": [],
            "Storage Path": [],
        }

        for stage in self.processing_stages:
            # Get database info
            db = self.databases[stage]
            path = self.database_paths[stage]

            # Get file modification time and size if file exists
            if path.exists():
                modified = path.stat().st_mtime
                modified_time = pd.Timestamp.fromtimestamp(modified).strftime("%Y-%m-%d %H:%M")
                size = path.stat().st_size / 1024  # Convert to KB
                size_str = f"{size:.1f} KB" if size < 1024 else f"{size / 1024:.1f} MB"
            else:
                modified_time = "Not created"
                size_str = "0 KB"

            # Append data
            data["Stage"].append(stage.capitalize())
            data["Entries"].append(len(db))
            data["Last Modified"].append(modified_time)
            data["Size"].append(size_str)
            data["Storage Path"].append(str(path.resolve()))

        # Create DataFrame
        info_df = pd.DataFrame(data)

        # Generate header row
        header_cells = " ".join(
            f'<th style="padding: 12px 15px; text-align: left;">{col}</th>'
            for col in info_df.columns
        )

        # Generate table rows
        table_rows = ""
        for _, row in info_df.iterrows():
            cells = "".join(f'<td style="padding: 12px 15px;">{val}</td>' for val in row)
            table_rows += f"<tr style='border-bottom: 1px solid #e9ecef;'>{cells}</tr>"

        # Create the complete HTML
        html = (
            """<style>
                @media (prefers-color-scheme: dark) {
                    .database-container { background-color: #1e1e1e !important; }
                    .database-title { color: #e0e0e0 !important; }
                    .database-table { background-color: #2d2d2d !important; }
                    .database-header { background-color: #4a4a4a !important; }
                    .database-cell { border-color: #404040 !important; }
                    .database-info { color: #b0b0b0 !important; }
                }
            </style>"""
            '<div style="font-family: Arial, sans-serif; padding: 20px; background:transparent; '
            'border-radius: 8px;">'
            f'<h3 style="color: #58bac7; margin-bottom: 15px;">{self.__class__.__name__}</h3>'
            '<div style="overflow-x: auto;">'
            '<table class="database-table" style="border-collapse: collapse; width: 100%;'
            ' box-shadow: 0 1px 3px rgba(0,0,0,0.1); background:transparent;">'
            # '<table style="border-collapse: collapse; width: 100%; background-color: white; '
            # 'box-shadow: 0 1px 3px rgba(0,0,0,0.1);">'
            "<thead>"
            f'<tr style="background-color: #58bac7; color: white;">{header_cells}</tr>'
            "</thead>"
            f"<tbody>{table_rows}</tbody>"
            "</table>"
            "</div>"
            '<div style="margin-top: 10px; color: #666; font-size: 1.1em;">'
            "</div>"
            "</div>"
        )
        return html
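
Before the per-method documentation below, a standalone sketch of the batching pattern used in `MPDatabase.save_cif_files` above: the material-id list is split into chunks of `MP_BATCH_SIZE` and each chunk is queried separately. The constant value and the ids below are placeholders, not the package's actual configuration.

Python
import numpy as np

MP_BATCH_SIZE = 1000  # placeholder; the package defines its own constant
ids_list = [f"mp-{i}" for i in range(2500)]  # placeholder material ids

n_batch = int(np.ceil(len(ids_list) / MP_BATCH_SIZE))
for i_batch in range(n_batch):
    i_start = i_batch * MP_BATCH_SIZE
    i_end = min((i_batch + 1) * MP_BATCH_SIZE, len(ids_list))
    batch = ids_list[i_start:i_end]
    # Each batch would be passed to mpr.materials.summary.search(material_ids=batch, ...)
    print(f"batch {i_batch}: {len(batch)} ids")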

__init__(name='mp', data_dir=DATA_DIR)

This constructor sets up the directory structure for storing data across different processing stages (raw, processed, final). It also initializes placeholders for the database paths and data specific to the MPDatabase class.

Parameters:

Name Type Description Default
name str

The name of the database. Defaults to "mp".

'mp'
data_dir Path or str

Root directory path for storing data. Defaults to DATA_DIR from the configuration.

DATA_DIR

Raises:

Type Description
NotImplementedError

If the specified processing stage is not supported.

ImmutableRawDataError

If attempting to set an unsupported processing stage.

Source code in energy_gnome/dataset/random_mats.py
Python
def __init__(self, name: str = "mp", data_dir: Path | str = DATA_DIR):
    """
    Initialize the MPDatabase with a root data directory and processing stage.

    This constructor sets up the directory structure for storing data across different
    processing stages (`raw`, `processed`, `final`). It also initializes placeholders
    for the database paths and data specific to the `MPDatabase` class.

    Args:
        name (str, optional): The name of the database. Defaults to "mp".
        data_dir (Path or str, optional): Root directory path for storing data. Defaults
            to `DATA_DIR` from the configuration.

    Raises:
        NotImplementedError: If the specified processing stage is not supported.
        ImmutableRawDataError: If attempting to set an unsupported processing stage.
    """
    super().__init__(name=name, data_dir=data_dir)

    # Force single directory for raw database of MPDatabase
    self.database_directories["raw"] = self.data_dir / "raw" / "mp"

    self._mp = pd.DataFrame()
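
Usage sketch (illustrative): this assumes `MPDatabase` is exported from `energy_gnome.dataset` and that the default `DATA_DIR` configuration is available; the custom path is a placeholder.

Python
from energy_gnome.dataset import MPDatabase

# Default construction; the raw stage is forced under <data_dir>/raw/mp.
mp_db = MPDatabase()

# Or with an explicit, project-specific data directory (hypothetical path).
mp_db_local = MPDatabase(name="mp", data_dir="./data")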

__repr__()

Text representation of the MPDatabase instance. Used for print() and str() calls.

Returns:

Name Type Description
str str

ASCII table representation of the database

Source code in energy_gnome/dataset/random_mats.py
Python
def __repr__(self) -> str:
    """
    Text representation of the MPDatabase instance.
    Used for `print()` and `str()` calls.

    Returns:
        str: ASCII table representation of the database
    """
    # Gather information about each stage
    data = {
        "Stage": [],
        "Entries": [],
        "Last Modified": [],
        "Size": [],
        "Storage Path": [],
    }

    # Calculate column widths
    widths = [10, 8, 17, 10, 55]

    for stage in self.processing_stages:
        # Get database info
        db = self.databases[stage]
        path = self.database_paths[stage]

        # Get file modification time and size if file exists
        if path.exists():
            modified = path.stat().st_mtime
            modified_time = pd.Timestamp.fromtimestamp(modified).strftime("%Y-%m-%d %H:%M")
            size = path.stat().st_size / 1024  # Convert to KB
            size_str = f"{size:.1f} KB" if size < 1024 else f"{size / 1024:.1f} MB"
        else:
            modified_time = "Not created"
            size_str = "0 KB"

        path_str = str(path.resolve())
        if len(path_str) > widths[4]:
            path_str = ".." + path_str[len(path_str) - widths[4] + 3 :]

        # Append data
        data["Stage"].append(stage.capitalize())
        data["Entries"].append(len(db))
        data["Last Modified"].append(modified_time)
        data["Size"].append(size_str)
        data["Storage Path"].append(path_str)

    # Create DataFrame
    info_df = pd.DataFrame(data)

    # Text representation for terminal/print
    def create_separator(widths):
        return "+" + "+".join("-" * (w + 1) for w in widths) + "+"

    # Create the text representation
    lines = []

    # Add title
    title = f" {self.__class__.__name__} Summary "
    lines.append(f"\n{title:=^{sum(widths) + len(widths) * 2 + 1}}")

    # Add header
    separator = create_separator(widths)
    lines.append(separator)

    header = (
        "|" + "|".join(f" {col:<{widths[i]}}" for i, col in enumerate(info_df.columns)) + "|"
    )
    lines.append(header)
    lines.append(separator)

    # Add data rows
    for _, row in info_df.iterrows():
        line = "|" + "|".join(f" {str(val):<{widths[i]}}" for i, val in enumerate(row)) + "|"
        lines.append(line)

    # Add bottom separator
    lines.append(separator)

    return "\n".join(lines)

remove_cross_overlap(stage, database)

Remove entries in the generic database that overlap with a category-specific database.

This function identifies and removes material entries that exist in both the generic and category-specific databases for a given processing stage.

Parameters:

Name Type Description Default
stage str

Processing stage (raw, processed, final).

required
database DataFrame

The category-specific database to compare with the generic database.

required

Returns:

Type Description
DataFrame

pd.DataFrame: The filtered generic database with overlapping entries removed.

Raises:

Type Description
ValueError

If an invalid processing stage is provided.

Logs
  • ERROR: If an invalid stage is given.
  • INFO: Number of overlapping entries identified and removed.
Source code in energy_gnome/dataset/random_mats.py
Python
def remove_cross_overlap(
    self,
    stage: str,
    database: pd.DataFrame,
) -> pd.DataFrame:
    """
    Remove entries in the generic database that overlap with a category-specific database.

    This function identifies and removes material entries that exist in both the generic
    and category-specific databases for a given processing stage.

    Args:
        stage (str): Processing stage (`raw`, `processed`, `final`).
        database (pd.DataFrame): The category-specific database to compare with the generic database.

    Returns:
        pd.DataFrame: The filtered generic database with overlapping entries removed.

    Raises:
        ValueError: If an invalid processing stage is provided.

    Logs:
        - ERROR: If an invalid stage is given.
        - INFO: Number of overlapping entries identified and removed.
    """
    if stage not in self.processing_stages:
        logger.error(f"Invalid stage: {stage}. Must be one of {self.processing_stages}.")
        raise ValueError(f"stage must be one of {self.processing_stages}.")

    self.load_database(stage)
    mp_database = self.get_database(stage)
    id_overlap = database["material_id"].tolist()

    mp_database["to_drop"] = mp_database.apply(
        lambda x: x["material_id"] in id_overlap, axis=1
    )
    to_drop = mp_database["to_drop"].value_counts().get(True, 0)
    logger.info(f"{to_drop} overlapping items to drop.")
    mp_database_no_overlap = mp_database.iloc[
        np.where(mp_database["to_drop"] == 0)[0], :
    ].reset_index(drop=True)
    mp_database_no_overlap.drop(columns=["to_drop"], inplace=True)

    return mp_database_no_overlap
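
A short usage sketch follows, continuing the `mp_db` instance from the sketch above; `specialized_df` stands for any category-specific DataFrame with a `material_id` column and is purely illustrative.

Python
# Drop from the generic raw MP database every material that already appears
# in the category-specific dataset.
clean_mp_df = mp_db.remove_cross_overlap(stage="raw", database=specialized_df)
print(f"{len(clean_mp_df)} generic materials remain after removing overlaps")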

retrieve_materials(max_framework_size=6, mute_progress_bars=True)

Retrieve materials from the Materials Project API.

This method connects to the Materials Project API using the MPRester, queries for all materials, and retrieves specified properties. The data is then cleaned by removing rows with missing critical fields. The method returns a cleaned DataFrame of materials.

Parameters:

Name Type Description Default
max_framework_size int

Maximum number of elements in the chemical systems queried (systems with 1 up to this many elements are retrieved). Defaults to 6.

6
mute_progress_bars bool

If True, mutes the Materials Project API progress bars during the request. Defaults to True.

True

Returns:

Type Description
DataFrame

A DataFrame containing the retrieved and cleaned materials.

Raises:

Type Description
Exception

If the API query fails or any issue occurs during data retrieval.

Logs
  • DEBUG: Logs the process of querying and cleaning data.
  • INFO: Logs successful query results and how many materials were retrieved.
  • SUCCESS: Logs the successful retrieval of materials.
Source code in energy_gnome/dataset/random_mats.py
Python
def retrieve_materials(
    self, max_framework_size: int = 6, mute_progress_bars: bool = True
) -> pd.DataFrame:
    """
    Retrieve materials from the Materials Project API.

    This method connects to the Materials Project API using the `MPRester`, queries
    for all materials, and retrieves specified properties. The data is then cleaned
    by removing rows with missing critical fields.
    The method returns a cleaned DataFrame of materials.

    Args:
        max_framework_size (int, optional):
            Maximum number of elements in the chemical systems queried (systems with
            1 up to this many elements are retrieved). Defaults to `6`.
        mute_progress_bars (bool, optional):
            If `True`, mutes the Materials Project API progress bars during the request.
            Defaults to `True`.

    Returns:
        (pd.DataFrame): A DataFrame containing the retrieved and cleaned materials.

    Raises:
        Exception: If the API query fails or any issue occurs during data retrieval.

    Logs:
        - DEBUG: Logs the process of querying and cleaning data.
        - INFO: Logs successful query results and how many materials were retrieved.
        - SUCCESS: Logs the successful retrieval of materials.
    """
    mp_api_key = get_mp_api_key()
    logger.debug("MP querying for all materials.")
    query = []

    with MPRester(mp_api_key, mute_progress_bars=mute_progress_bars) as mpr:
        try:
            for n_elm in range(1, max_framework_size + 1):
                chemsys = "*" + "-*" * n_elm
                logger.info(f"Retrieving all materials with chemical system = {chemsys} :")
                query += mpr.materials.summary.search(chemsys=chemsys, fields=MAT_PROPERTIES)
            logger.info(f"MP query successful, {len(query)} materials found.")
        except Exception as e:
            raise e
    logger.debug("Converting MP query results into DataFrame.")
    mp_database = convert_my_query_to_dataframe(query, mute_progress_bars=mute_progress_bars)

    # Fast cleaning
    logger.debug("Removing NaN (rows)")
    logger.debug(f"size DB before = {len(mp_database)}")
    mp_database = mp_database.dropna(axis=0, how="any", subset=CRITICAL_FIELD)
    logger.debug(f"size DB after = {len(mp_database)}")
    logger.debug("Removing NaN (cols)")
    logger.debug(f"size DB before = {len(mp_database)}")
    mp_database = mp_database.dropna(axis=1, how="all")
    logger.debug(f"size DB after = {len(mp_database)}")

    mp_database.reset_index(drop=True, inplace=True)
    self._mp = mp_database.copy()

    logger.success("Materials retrieved successfully.")
    return self._mp
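
A hedged usage sketch, assuming a valid Materials Project API key is configured so that `get_mp_api_key()` can resolve it:

Python
# Query chemical systems with 1 to 4 elements and show the API progress bars.
mp_df = mp_db.retrieve_materials(max_framework_size=4, mute_progress_bars=False)
print(f"Retrieved {len(mp_df)} cleaned materials")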

save_cif_files(stage, mute_progress_bars=True)

Save CIF files for materials and update the database efficiently.

This method retrieves crystal structures from the Materials Project API and saves them as CIF files in the appropriate directory. It ensures raw data integrity and efficiently updates the database with CIF file paths.

Parameters:

Name Type Description Default
stage str

The processing stage (raw, processed, final).

required
mute_progress_bars bool

If True, disables progress bars. Defaults to True.

True

Raises:

Type Description
ImmutableRawDataError

If attempting to modify immutable raw data.

Logs
  • WARNING: If the CIF directory is being cleaned or if a material ID is missing.
  • ERROR: If an API query fails or CIF file saving encounters an error.
  • INFO: When CIF files are successfully retrieved and saved.
Source code in energy_gnome/dataset/random_mats.py
Python
def save_cif_files(
    self,
    stage: str,
    mute_progress_bars: bool = True,
) -> None:
    """
    Save CIF files for materials and update the database efficiently.

    This method retrieves crystal structures from the Materials Project API and saves them
    as CIF files in the appropriate directory. It ensures raw data integrity and efficiently
    updates the database with CIF file paths.

    Args:
        stage (str): The processing stage (`raw`, `processed`, `final`).
        mute_progress_bars (bool, optional): If True, disables progress bars. Defaults to True.

    Raises:
        ImmutableRawDataError: If attempting to modify immutable raw data.

    Logs:
        - WARNING: If the CIF directory is being cleaned or if a material ID is missing.
        - ERROR: If an API query fails or CIF file saving encounters an error.
        - INFO: When CIF files are successfully retrieved and saved.
    """

    # Set up directory for saving CIF files
    saving_dir = self.database_directories[stage] / "structures/"
    database = self.get_database(stage)

    # Ensure raw data integrity
    if stage == "raw" and not self._update_raw:
        logger.error("Raw data must be treated as immutable!")
        raise ImmutableRawDataError("Raw data must be treated as immutable!")

    # Clear directory if it exists
    if saving_dir.exists():
        logger.warning(f"Cleaning {saving_dir}")
        sh.rmtree(saving_dir)
    saving_dir.mkdir(parents=True, exist_ok=False)

    # Create a lookup dictionary for material IDs → DataFrame row indices
    material_id_to_index = {mid: idx for idx, mid in enumerate(database["material_id"])}

    # Fetch structures in batches
    ids_list = database["material_id"].tolist()
    n_batch = int(np.ceil(len(ids_list) / MP_BATCH_SIZE))
    mp_api_key = get_mp_api_key()

    logger.debug("MP querying for materials' structures.")
    with MPRester(mp_api_key, mute_progress_bars=mute_progress_bars) as mpr:
        all_cif_paths = {}  # Store updates in a dict to vectorize DataFrame updates later

        for i_batch in tqdm(
            range(n_batch), desc="Saving materials", disable=mute_progress_bars
        ):
            i_star = i_batch * MP_BATCH_SIZE
            i_end = min((i_batch + 1) * MP_BATCH_SIZE, len(ids_list))

            try:
                materials_mp_query = mpr.materials.summary.search(
                    material_ids=ids_list[i_star:i_end], fields=["material_id", "structure"]
                )
                logger.info(
                    f"MP query successful, {len(materials_mp_query)} structures found."
                )
            except Exception as e:
                logger.error(f"Failed MP query: {e}")
                raise e

            # Define a function to save CIF files in parallel
            def save_cif(material):
                try:
                    material_id = material.material_id
                    if material_id not in material_id_to_index:
                        logger.warning(f"Material ID {material_id} not found in database.")
                        return None

                    cif_path = saving_dir / f"{material_id}.cif"
                    material.structure.to(filename=str(cif_path))
                    return material_id, str(cif_path)

                except Exception as e:
                    logger.error(f"Failed to save CIF for {material.material_id}: {e}")
                    return None

            # Parallelize CIF saving (adjust max_workers based on your system)
            with ThreadPoolExecutor(max_workers=None) as executor:
                results = list(executor.map(save_cif, materials_mp_query))

            # Collect results in dictionary for bulk DataFrame update
            for result in results:
                if result:
                    material_id, cif_path = result
                    all_cif_paths[material_id] = cif_path

    # Bulk update DataFrame in one step (vectorized)
    database["cif_path"] = database["material_id"].map(all_cif_paths)

    # Save the updated database
    self.save_database(stage)
    logger.info(f"CIF files for stage '{stage}' saved and database updated successfully.")

energy_gnome.dataset.GNoMEDatabase

Bases: BaseDatabase

Source code in energy_gnome/dataset/gnome.py
Python
class GNoMEDatabase(BaseDatabase):
    def __init__(self, name: str = "gnome", data_dir: Path | str = DATA_DIR):
        """
        Initialize the GNoMEDatabase with a root data directory and processing stage.

        This constructor sets up the directory structure for storing data across different
        processing stages (`raw`, `processed`, `final`). It also initializes placeholders
        for the database paths and data specific to the `GNoMEDatabase` class.

        Args:
            name (str, optional): The name of the database. Defaults to "gnome".
            data_dir (Path or str, optional): Root directory path for storing data. Defaults
                to `DATA_DIR` from the configuration.

        Raises:
            NotImplementedError: If the specified processing stage is not supported.
            ImmutableRawDataError: If attempting to set an unsupported processing stage.
        """
        super().__init__(name=name, data_dir=data_dir)

        # Force single directory for raw database of GNoMEDatabase
        self.database_directories["raw"] = self.data_dir / "raw" / "gnome"

        self._gnome = pd.DataFrame()

    def _set_is_specialized(
        self,
    ):
        pass

    def retrieve_materials(self) -> pd.DataFrame:
        """
        TBD (after implementing the fetch routine)
        """
        csv_db = pd.read_csv(GNoME_DATA_DIR / "stable_materials_summary.csv", index_col=0)
        csv_db = csv_db.rename(columns={"MaterialId": "material_id"})

        return csv_db

    def save_cif_files(self) -> None:
        """
        Save CIF files for materials and update the database accordingly.
        Uses OS-native unzippers for maximum speed.
        """
        zip_path = GNoME_DATA_DIR / "by_id.zip"
        output_path = self.database_directories["raw"] / "structures"
        output_path.mkdir(parents=True, exist_ok=True)

        logger.info("Unzipping structures from")
        logger.info(f"{zip_path}")
        logger.info("to")
        logger.info(f"{output_path}")
        logger.info("using native OS unzip...")

        # OS-native unzip
        if os.name == "nt":  # Windows
            unzip_command = ["tar", "-xf", zip_path, "-C", output_path]
        else:  # Linux/macOS
            unzip_command = ["unzip", "-o", zip_path, "-d", output_path]

        try:
            subprocess.run(unzip_command, check=True, capture_output=True, text=True)
            logger.info("Extraction complete!")
        except subprocess.CalledProcessError as e:
            logger.error(f"Error unzipping CIF file: {e.stderr}")
            return

        # Check if "by_id" subfolder exists
        extracted_dir = output_path / "by_id"
        if extracted_dir.exists() and extracted_dir.is_dir():
            logger.info("Flattening extracted files by moving contents from")
            logger.info(f"{extracted_dir}")
            logger.info("to")
            logger.info(f"{output_path}")

            if os.name == "nt":
                # Windows - robocopy moves the tree; it returns 1 when files are moved,
                # so any exit code below 8 counts as success
                move_command = f'robocopy "{extracted_dir}" "{output_path}" /move /e'
            else:
                # Linux/macOS - the glob needs shell expansion, so pass a single command string
                move_command = f'mv "{extracted_dir}"/* "{output_path}"'

            # shell=True requires a command string, not a list
            result = subprocess.run(
                move_command, shell=True, capture_output=True, text=True
            )
            moved_ok = result.returncode < 8 if os.name == "nt" else result.returncode == 0
            if not moved_ok:
                logger.error(f"Error moving files: {result.stderr}")
                return
            extracted_dir.rmdir()  # Remove the now-empty folder

        # Update Database
        df = self.get_database("raw")
        df["cif_path"] = (
            df["material_id"].astype(str).apply(lambda x: (output_path / f"{x}.CIF").as_posix())
        )

        self.save_database("raw")
        logger.info("CIF files saved and database updated successfully.")

    def copy_cif_files(self):
        pass

    def filter_by_elements(
        self,
        include: list[str] = None,
        exclude: list[str] = None,
        stage: str = "final",
        save_filtered_db: bool = False,
    ) -> pd.DataFrame:
        """
        Filters the database entries based on the presence or absence of specified chemical elements or element groups.

        Args:
            include (list[str], optional): A list of chemical elements that must be present in the composition.
                                           - If None, no filtering is applied based on inclusion.
                                           - If an element group is passed using the format `"A-B"`, the material must
                                             contain *all* elements in that group.
            exclude (list[str], optional): A list of chemical elements that must not be present in the composition.
                                           - If None, no filtering is applied based on exclusion.
                                           - If an element group is passed using the format `"A-B"`, the material is
                                             removed *only* if it contains *all* elements in that group.
            stage (str, optional): The processing stage to retrieve the database from. Defaults to `"final"`.
                                   Possible values: `"raw"`, `"processed"`, `"final"`.
            save_filtered_db (bool, optional): If True, saves the filtered database back to `self.databases[stage]`
                                               and updates the stored database. Defaults to False.

        Returns:
            (pd.DataFrame): A filtered DataFrame containing only the entries that match the inclusion/exclusion
                            criteria.

        Raises:
            ValueError: If both `include` and `exclude` lists are empty.

        Notes:
            - If both `include` and `exclude` are provided, the function will return entries that contain at least one
              of the `include` elements but none of the `exclude` elements.
            - If an entry in `include` is in the format `"A-B"`, the material must contain all elements in that group.
            - If an entry in `exclude` is in the format `"A-B"`, the material is removed only if it contains *all*
              elements in that group.
            - If `save_filtered_db` is True, the filtered DataFrame is stored in `self.databases[stage]` and saved
              persistently.
        """

        if include is None:
            include = []
        if exclude is None:
            exclude = []

        if not include and not exclude:
            raise ValueError("At least one of `include` or `exclude` must be specified.")

        df = self.get_database(stage)

        def parse_elements(elements_str):
            """Convert the Elements column from string to an actual list."""
            if isinstance(elements_str, str):
                try:
                    elements = eval(elements_str)  # Convert string to list
                    if isinstance(elements, list):
                        return elements
                except Exception as e:
                    logger.warning(f"{e}")
            return elements_str if isinstance(elements_str, list) else []

        def contains_elements(elements, elements_list: list[str]) -> bool:
            """Check if an entry satisfies the element inclusion criteria."""
            material_elements = set(parse_elements(elements))  # Convert material elements to a set

            simple_elements = {
                e for e in elements_list if "-" not in e
            }  # Elements that can appear individually
            grouped_elements = [
                set(e.split("-")) for e in elements_list if "-" in e
            ]  # Element groups that must all be present

            # Check if any simple element is present
            simple_match = bool(material_elements & simple_elements) if simple_elements else False

            # Check if all elements in at least one group are present
            grouped_match = (
                any(group.issubset(material_elements) for group in grouped_elements)
                if grouped_elements
                else False
            )

            return (
                simple_match or grouped_match
            )  # Material passes if it satisfies either condition

        # Apply filtering
        if include:
            df = df[df["Elements"].apply(lambda x: contains_elements(x, include))]
        if exclude:
            df = df[~df["Elements"].apply(lambda x: contains_elements(x, exclude))]

        if save_filtered_db:
            self.databases[stage] = df
            self.save_database(stage)
            logger.info(f"Saved filtered database in stage {stage}.")

        return df

    def __repr__(self) -> str:
        """
        Text representation of the GNoMEDatabase instance.
        Used for `print()` and `str()` calls.

        Returns:
            str: ASCII table representation of the database
        """
        # Gather information about each stage
        data = {
            "Stage": [],
            "Entries": [],
            "Last Modified": [],
            "Size": [],
            "Storage Path": [],
        }

        # Calculate column widths
        widths = [10, 8, 17, 10, 55]

        for stage in self.processing_stages:
            # Get database info
            db = self.databases[stage]
            path = self.database_paths[stage]

            # Get file modification time and size if file exists
            if path.exists():
                modified = path.stat().st_mtime
                modified_time = pd.Timestamp.fromtimestamp(modified).strftime("%Y-%m-%d %H:%M")
                size = path.stat().st_size / 1024  # Convert to KB
                size_str = f"{size:.1f} KB" if size < 1024 else f"{size / 1024:.1f} MB"
            else:
                modified_time = "Not created"
                size_str = "0 KB"

            path_str = str(path.resolve())
            if len(path_str) > widths[4]:
                path_str = ".." + path_str[len(path_str) - widths[4] + 3 :]

            # Append data
            data["Stage"].append(stage.capitalize())
            data["Entries"].append(len(db))
            data["Last Modified"].append(modified_time)
            data["Size"].append(size_str)
            data["Storage Path"].append(path_str)

        # Create DataFrame
        info_df = pd.DataFrame(data)

        # Text representation for terminal/print
        def create_separator(widths):
            return "+" + "+".join("-" * (w + 1) for w in widths) + "+"

        # Create the text representation
        lines = []

        # Add title
        title = f" {self.__class__.__name__} Summary "
        lines.append(f"\n{title:=^{sum(widths) + len(widths) * 2 + 1}}")

        # Add header
        separator = create_separator(widths)
        lines.append(separator)

        header = (
            "|" + "|".join(f" {col:<{widths[i]}}" for i, col in enumerate(info_df.columns)) + "|"
        )
        lines.append(header)
        lines.append(separator)

        # Add data rows
        for _, row in info_df.iterrows():
            line = "|" + "|".join(f" {str(val):<{widths[i]}}" for i, val in enumerate(row)) + "|"
            lines.append(line)

        # Add bottom separator
        lines.append(separator)

        return "\n".join(lines)

    def _repr_html_(self) -> str:
        """
        HTML representation of the GNoMEDatabase instance.
        Used for Jupyter notebook display.

        Returns:
            str: HTML representation of the database
        """
        # Gather information about each stage
        data = {
            "Stage": [],
            "Entries": [],
            "Last Modified": [],
            "Size": [],
            "Storage Path": [],
        }

        for stage in self.processing_stages:
            # Get database info
            db = self.databases[stage]
            path = self.database_paths[stage]

            # Get file modification time and size if file exists
            if path.exists():
                modified = path.stat().st_mtime
                modified_time = pd.Timestamp.fromtimestamp(modified).strftime("%Y-%m-%d %H:%M")
                size = path.stat().st_size / 1024  # Convert to KB
                size_str = f"{size:.1f} KB" if size < 1024 else f"{size / 1024:.1f} MB"
            else:
                modified_time = "Not created"
                size_str = "0 KB"

            # Append data
            data["Stage"].append(stage.capitalize())
            data["Entries"].append(len(db))
            data["Last Modified"].append(modified_time)
            data["Size"].append(size_str)
            data["Storage Path"].append(str(path.resolve()))

        # Create DataFrame
        info_df = pd.DataFrame(data)

        # Generate header row
        header_cells = " ".join(
            f'<th style="padding: 12px 15px; text-align: left;">{col}</th>'
            for col in info_df.columns
        )

        # Generate table rows
        table_rows = ""
        for _, row in info_df.iterrows():
            cells = "".join(f'<td style="padding: 12px 15px;">{val}</td>' for val in row)
            table_rows += f"<tr style='border-bottom: 1px solid #e9ecef;'>{cells}</tr>"

        # Create the complete HTML
        html = (
            """<style>
                @media (prefers-color-scheme: dark) {
                    .database-container { background-color: #1e1e1e !important; }
                    .database-title { color: #e0e0e0 !important; }
                    .database-table { background-color: #2d2d2d !important; }
                    .database-header { background-color: #4a4a4a !important; }
                    .database-cell { border-color: #404040 !important; }
                    .database-info { color: #b0b0b0 !important; }
                }
            </style>"""
            '<div style="font-family: Arial, sans-serif; padding: 20px; background:transparent; '
            'border-radius: 8px;">'
            f'<h3 style="color: #58bac7; margin-bottom: 15px;">{self.__class__.__name__}</h3>'
            '<div style="overflow-x: auto;">'
            '<table class="database-table" style="border-collapse: collapse; width: 100%;'
            ' box-shadow: 0 1px 3px rgba(0,0,0,0.1); background:transparent;">'
            # '<table style="border-collapse: collapse; width: 100%; background-color: white; '
            # 'box-shadow: 0 1px 3px rgba(0,0,0,0.1);">'
            "<thead>"
            f'<tr style="background-color: #58bac7; color: white;">{header_cells}</tr>'
            "</thead>"
            f"<tbody>{table_rows}</tbody>"
            "</table>"
            "</div>"
            '<div style="margin-top: 10px; color: #666; font-size: 1.1em;">'
            "</div>"
            "</div>"
        )
        return html

__init__(name='gnome', data_dir=DATA_DIR)

This constructor sets up the directory structure for storing data across different processing stages (raw, processed, final). It also initializes placeholders for the database paths and data specific to the GNoMEDatabase class.

Parameters:

Name Type Description Default
name str

The name of the database. Defaults to "gnome".

'gnome'
data_dir Path or str

Root directory path for storing data. Defaults to DATA_DIR from the configuration.

DATA_DIR

Raises:

Type Description
NotImplementedError

If the specified processing stage is not supported.

ImmutableRawDataError

If attempting to set an unsupported processing stage.

Source code in energy_gnome/dataset/gnome.py
Python
def __init__(self, name: str = "gnome", data_dir: Path | str = DATA_DIR):
    """
    Initialize the GNoMEDatabase with a root data directory and processing stage.

    This constructor sets up the directory structure for storing data across different
    processing stages (`raw`, `processed`, `final`). It also initializes placeholders
    for the database paths and data specific to the `GNoMEDatabase` class.

    Args:
        name (str, optional): The name of the database. Defaults to "gnome".
        data_dir (Path or str, optional): Root directory path for storing data. Defaults
            to `DATA_DIR` from the configuration.

    Raises:
        NotImplementedError: If the specified processing stage is not supported.
        ImmutableRawDataError: If attempting to set an unsupported processing stage.
    """
    super().__init__(name=name, data_dir=data_dir)

    # Force single directory for raw database of GNoMEDatabase
    self.database_directories["raw"] = self.data_dir / "raw" / "gnome"

    self._gnome = pd.DataFrame()
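
A minimal instantiation sketch (import path assumed from this module page):

Python
from energy_gnome.dataset import GNoMEDatabase

gnome_db = GNoMEDatabase()  # raw data is forced under <data_dir>/raw/gnome
print(gnome_db)             # ASCII summary of the raw/processed/final stages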

__repr__()

Text representation of the GNoMEDatabase instance. Used for print() and str() calls.

Returns:

Name Type Description
str str

ASCII table representation of the database

Source code in energy_gnome/dataset/gnome.py
Python
def __repr__(self) -> str:
    """
    Text representation of the GNoMEDatabase instance.
    Used for `print()` and `str()` calls.

    Returns:
        str: ASCII table representation of the database
    """
    # Gather information about each stage
    data = {
        "Stage": [],
        "Entries": [],
        "Last Modified": [],
        "Size": [],
        "Storage Path": [],
    }

    # Calculate column widths
    widths = [10, 8, 17, 10, 55]

    for stage in self.processing_stages:
        # Get database info
        db = self.databases[stage]
        path = self.database_paths[stage]

        # Get file modification time and size if file exists
        if path.exists():
            modified = path.stat().st_mtime
            modified_time = pd.Timestamp.fromtimestamp(modified).strftime("%Y-%m-%d %H:%M")
            size = path.stat().st_size / 1024  # Convert to KB
            size_str = f"{size:.1f} KB" if size < 1024 else f"{size / 1024:.1f} MB"
        else:
            modified_time = "Not created"
            size_str = "0 KB"

        path_str = str(path.resolve())
        if len(path_str) > widths[4]:
            path_str = ".." + path_str[len(path_str) - widths[4] + 3 :]

        # Append data
        data["Stage"].append(stage.capitalize())
        data["Entries"].append(len(db))
        data["Last Modified"].append(modified_time)
        data["Size"].append(size_str)
        data["Storage Path"].append(path_str)

    # Create DataFrame
    info_df = pd.DataFrame(data)

    # Text representation for terminal/print
    def create_separator(widths):
        return "+" + "+".join("-" * (w + 1) for w in widths) + "+"

    # Create the text representation
    lines = []

    # Add title
    title = f" {self.__class__.__name__} Summary "
    lines.append(f"\n{title:=^{sum(widths) + len(widths) * 2 + 1}}")

    # Add header
    separator = create_separator(widths)
    lines.append(separator)

    header = (
        "|" + "|".join(f" {col:<{widths[i]}}" for i, col in enumerate(info_df.columns)) + "|"
    )
    lines.append(header)
    lines.append(separator)

    # Add data rows
    for _, row in info_df.iterrows():
        line = "|" + "|".join(f" {str(val):<{widths[i]}}" for i, val in enumerate(row)) + "|"
        lines.append(line)

    # Add bottom separator
    lines.append(separator)

    return "\n".join(lines)

filter_by_elements(include=None, exclude=None, stage='final', save_filtered_db=False)

Filters the database entries based on the presence or absence of specified chemical elements or element groups.

Parameters:

Name Type Description Default
include list[str]

A list of chemical elements that must be present in the composition. If None, no filtering is applied based on inclusion. If an element group is passed using the format "A-B", the material must contain all elements in that group.

None
exclude list[str]

A list of chemical elements that must not be present in the composition. If None, no filtering is applied based on exclusion. If an element group is passed using the format "A-B", the material is removed only if it contains all elements in that group.

None
stage str

The processing stage to retrieve the database from. Defaults to "final". Possible values: "raw", "processed", "final".

'final'
save_filtered_db bool

If True, saves the filtered database back to self.databases[stage] and updates the stored database. Defaults to False.

False

Returns:

Type Description
DataFrame

A filtered DataFrame containing only the entries that match the inclusion/exclusion criteria.

Raises:

Type Description
ValueError

If both include and exclude lists are empty.

Notes
  • If both include and exclude are provided, the function will return entries that contain at least one of the include elements but none of the exclude elements.
  • If an entry in include is in the format "A-B", the material must contain all elements in that group.
  • If an entry in exclude is in the format "A-B", the material is removed only if it contains all elements in that group.
  • If save_filtered_db is True, the filtered DataFrame is stored in self.databases[stage] and saved persistently.
Source code in energy_gnome/dataset/gnome.py
Python
def filter_by_elements(
    self,
    include: list[str] = None,
    exclude: list[str] = None,
    stage: str = "final",
    save_filtered_db: bool = False,
) -> pd.DataFrame:
    """
    Filters the database entries based on the presence or absence of specified chemical elements or element groups.

    Args:
        include (list[str], optional): A list of chemical elements that must be present in the composition.
                                       - If None, no filtering is applied based on inclusion.
                                       - If an element group is passed using the format `"A-B"`, the material must
                                         contain *all* elements in that group.
        exclude (list[str], optional): A list of chemical elements that must not be present in the composition.
                                       - If None, no filtering is applied based on exclusion.
                                       - If an element group is passed using the format `"A-B"`, the material is
                                         removed *only* if it contains *all* elements in that group.
        stage (str, optional): The processing stage to retrieve the database from. Defaults to `"final"`.
                               Possible values: `"raw"`, `"processed"`, `"final"`.
        save_filtered_db (bool, optional): If True, saves the filtered database back to `self.databases[stage]`
                                           and updates the stored database. Defaults to False.

    Returns:
        (pd.DataFrame): A filtered DataFrame containing only the entries that match the inclusion/exclusion
                        criteria.

    Raises:
        ValueError: If both `include` and `exclude` lists are empty.

    Notes:
        - If both `include` and `exclude` are provided, the function will return entries that contain at least one
          of the `include` elements but none of the `exclude` elements.
        - If an entry in `include` is in the format `"A-B"`, the material must contain all elements in that group.
        - If an entry in `exclude` is in the format `"A-B"`, the material is removed only if it contains *all*
          elements in that group.
        - If `save_filtered_db` is True, the filtered DataFrame is stored in `self.databases[stage]` and saved
          persistently.
    """

    if include is None:
        include = []
    if exclude is None:
        exclude = []

    if not include and not exclude:
        raise ValueError("At least one of `include` or `exclude` must be specified.")

    df = self.get_database(stage)

    def parse_elements(elements_str):
        """Convert the Elements column from string to an actual list."""
        if isinstance(elements_str, str):
            try:
                elements = eval(elements_str)  # Convert string to list
                if isinstance(elements, list):
                    return elements
            except Exception as e:
                logger.warning(f"{e}")
        return elements_str if isinstance(elements_str, list) else []

    def contains_elements(elements, elements_list: list[str]) -> bool:
        """Check if an entry satisfies the element inclusion criteria."""
        material_elements = set(parse_elements(elements))  # Convert material elements to a set

        simple_elements = {
            e for e in elements_list if "-" not in e
        }  # Elements that can appear individually
        grouped_elements = [
            set(e.split("-")) for e in elements_list if "-" in e
        ]  # Element groups that must all be present

        # Check if any simple element is present
        simple_match = bool(material_elements & simple_elements) if simple_elements else False

        # Check if all elements in at least one group are present
        grouped_match = (
            any(group.issubset(material_elements) for group in grouped_elements)
            if grouped_elements
            else False
        )

        return (
            simple_match or grouped_match
        )  # Material passes if it satisfies either condition

    # Apply filtering
    if include:
        df = df[df["Elements"].apply(lambda x: contains_elements(x, include))]
    if exclude:
        df = df[~df["Elements"].apply(lambda x: contains_elements(x, exclude))]

    if save_filtered_db:
        self.databases[stage] = df
        self.save_database(stage)
        logger.info(f"Saved filtered database in stage {stage}.")

    return df
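
A usage sketch illustrating the element-group syntax, assuming the final-stage database has been populated: it keeps materials containing Li, or both Co and O together, and drops anything containing Pb.

Python
filtered = gnome_db.filter_by_elements(
    include=["Li", "Co-O"],  # "Co-O" matches only if Co and O are both present
    exclude=["Pb"],
    stage="final",
    save_filtered_db=False,
)
print(f"{len(filtered)} entries match the element criteria")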

retrieve_materials()

TBD (after implementing the fetch routine)

Source code in energy_gnome/dataset/gnome.py
Python
def retrieve_materials(self) -> pd.DataFrame:
    """
    TBD (after implementing the fetch routine)
    """
    csv_db = pd.read_csv(GNoME_DATA_DIR / "stable_materials_summary.csv", index_col=0)
    csv_db = csv_db.rename(columns={"MaterialId": "material_id"})

    return csv_db
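
A short sketch of the current CSV-based routine; it assumes `stable_materials_summary.csv` has already been downloaded into `GNoME_DATA_DIR`:

Python
raw_df = gnome_db.retrieve_materials()  # reads GNoME_DATA_DIR / "stable_materials_summary.csv"
print(f"{len(raw_df)} GNoME entries loaded")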

save_cif_files()

Save CIF files for materials and update the database accordingly. Uses OS-native unzippers for maximum speed.

Source code in energy_gnome/dataset/gnome.py
Python
def save_cif_files(self) -> None:
    """
    Save CIF files for materials and update the database accordingly.
    Uses OS-native unzippers for maximum speed.
    """
    zip_path = GNoME_DATA_DIR / "by_id.zip"
    output_path = self.database_directories["raw"] / "structures"
    output_path.mkdir(parents=True, exist_ok=True)

    logger.info("Unzipping structures from")
    logger.info(f"{zip_path}")
    logger.info("to")
    logger.info(f"{output_path}")
    logger.info("using native OS unzip...")

    # OS-native unzip
    if os.name == "nt":  # Windows
        unzip_command = ["tar", "-xf", zip_path, "-C", output_path]
    else:  # Linux/macOS
        unzip_command = ["unzip", "-o", zip_path, "-d", output_path]

    try:
        subprocess.run(unzip_command, check=True, capture_output=True, text=True)
        logger.info("Extraction complete!")
    except subprocess.CalledProcessError as e:
        logger.error(f"Error unzipping CIF file: {e.stderr}")
        return

    # Check if "by_id" subfolder exists
    extracted_dir = output_path / "by_id"
    if extracted_dir.exists() and extracted_dir.is_dir():
        logger.info("Flattening extracted files by moving contents from")
        logger.info(f"{extracted_dir}")
        logger.info("to")
        logger.info(f"{output_path}")

        if os.name == "nt":
            # Windows - robocopy moves the tree; it returns 1 when files are moved,
            # so any exit code below 8 counts as success
            move_command = f'robocopy "{extracted_dir}" "{output_path}" /move /e'
        else:
            # Linux/macOS - the glob needs shell expansion, so pass a single command string
            move_command = f'mv "{extracted_dir}"/* "{output_path}"'

        # shell=True requires a command string, not a list
        result = subprocess.run(
            move_command, shell=True, capture_output=True, text=True
        )
        moved_ok = result.returncode < 8 if os.name == "nt" else result.returncode == 0
        if not moved_ok:
            logger.error(f"Error moving files: {result.stderr}")
            return
        extracted_dir.rmdir()  # Remove the now-empty folder

    # Update Database
    df = self.get_database("raw")
    df["cif_path"] = (
        df["material_id"].astype(str).apply(lambda x: (output_path / f"{x}.CIF").as_posix())
    )

    self.save_database("raw")
    logger.info("CIF files saved and database updated successfully.")