Source code for balance.testutil

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed according to the terms of the
# GNU General Public License version 2.

# pyre-unsafe

import io
import re
import sys

import unittest
from contextlib import contextmanager
from typing import Any, Union

import numpy as np
import numpy.typing as npt
import pandas as pd


def _assert_frame_equal_lazy(
    x: pd.DataFrame, y: pd.DataFrame, lazy: bool = True
) -> None:
    """Assert that two DataFrames are equal, optionally ignoring ordering.

    Thin wrapper around ``pd.testing.assert_frame_equal``. In lazy mode both
    frames are first sorted by their row index and then by their column
    labels, so differences in row order or column order alone do not cause
    a failure.

    Args:
        x (pd.DataFrame): First DataFrame to compare.
        y (pd.DataFrame): Second DataFrame to compare.
        lazy (bool, optional): Whether to sort both frames before comparing.
            Defaults to True.

    Returns:
        None.

    Raises:
        AssertionError: If the (possibly sorted) frames differ.
    """
    if not lazy:
        # Strict mode: compare the frames exactly as given.
        return pd.testing.assert_frame_equal(x, y)

    # Normalise rows first, then columns, for each frame in turn.
    left, right = (
        frame.sort_index(axis=0).sort_index(axis=1) for frame in (x, y)
    )
    return pd.testing.assert_frame_equal(left, right)


def _assert_index_equal_lazy(x: pd.Index, y: pd.Index, lazy: bool = True) -> None:
    """
    Assert that two pandas indexes are equal, optionally ignoring order.

    Thin wrapper around ``pd.testing.assert_index_equal``. In lazy mode both
    indexes are sorted by value before the comparison, so the order of the
    entries does not matter.

    Args:
        x (pd.Index): First index to compare.
        y (pd.Index): Second index to compare.
        lazy (bool, optional): Whether to sort both indexes before comparing.
            Defaults to True.

    Raises:
        AssertionError: If the (possibly sorted) indexes differ.
    """
    # Either the value-sorted copies (lazy) or the inputs untouched (strict).
    left, right = (x.sort_values(), y.sort_values()) if lazy else (x, y)
    return pd.testing.assert_index_equal(left, right)


@contextmanager
def _capture_output():
    """Temporarily redirect ``sys.stdout`` and ``sys.stderr`` to buffers.

    Yields:
        A ``(stdout, stderr)`` pair of ``io.StringIO`` objects that receive
        everything written to the standard streams inside the ``with`` body.

    The original streams are put back when the block exits, including when
    the body raises.
    """
    saved_streams = sys.stdout, sys.stderr
    buffers = io.StringIO(), io.StringIO()
    try:
        sys.stdout, sys.stderr = buffers
        yield sys.stdout, sys.stderr
    finally:
        # Always restore, even when the wrapped code raised.
        sys.stdout, sys.stderr = saved_streams


class BalanceTestCase(unittest.TestCase):
    """``unittest.TestCase`` with extra assertions for log records,
    numpy/pandas equality, and printed output."""

    def _collect_logs(self, level, func, args, kwargs):
        """Run ``func(*args, **kwargs)`` and return the log lines it emitted.

        Args:
            level: Level passed to ``assertLogs`` (``None`` keeps the
                ``assertLogs`` default).
            func: The callable to run.
            args: Positional arguments for ``func``.
            kwargs: Keyword arguments for ``func``.

        Returns:
            list: The formatted log lines, or an empty list when nothing was
            logged at ``level`` or above.

        Raises:
            Any exception raised by ``func`` itself, including AssertionError.
        """
        finished = False
        try:
            with self.assertLogs(level=level) as cm:
                func(*args, **kwargs)
                finished = True
        except AssertionError:
            # assertLogs signals "nothing was logged" with an AssertionError
            # on exit. Only treat it that way when func ran to completion;
            # otherwise the error came from func and must not be hidden.
            if not finished:
                raise
            return []
        return list(cm.output)

    # Some Warns
    def assertIfWarns(self, callable, *args, **kwargs) -> None:
        """Assert that calling ``callable(*args, **kwargs)`` emits at least
        one log record (any level)."""
        with self.assertLogs(level="NOTSET") as cm:
            callable(*args, **kwargs)
        self.assertTrue(len(cm.output) > 0, "No warning produced.")

    def assertNotWarns(self, callable, *args, **kwargs) -> None:
        """Assert that calling ``callable(*args, **kwargs)`` emits no log
        record at the default ``assertLogs`` level.

        An AssertionError raised by ``callable`` itself propagates instead of
        being mistaken for "no logs".
        """
        output = self._collect_logs(None, callable, args, kwargs)
        if output:
            raise AssertionError(f"Warning produced {output}.")

    def assertWarnsRegexp(self, regexp, callable, *args, **kwargs) -> None:
        """Assert that at least one log line emitted by
        ``callable(*args, **kwargs)`` matches ``regexp`` (``re.search``)."""
        with self.assertLogs(level="NOTSET") as cm:
            callable(*args, **kwargs)
        self.assertTrue(
            any((re.search(regexp, c) is not None) for c in cm.output),
            f"Warning {cm.output} does not match regex {regexp}.",
        )

    def assertNotWarnsRegexp(self, regexp, callable, *args, **kwargs) -> None:
        """Assert that no log line emitted by ``callable(*args, **kwargs)``
        matches ``regexp`` (``re.search``).

        A callable that logs nothing passes: there is no line to match.
        """
        output = self._collect_logs("NOTSET", callable, args, kwargs)
        self.assertFalse(
            any((re.search(regexp, c) is not None) for c in output),
            f"Warning {output} matches regex {regexp}.",
        )

    # Some Equal
    def assertEqual(
        self,
        first: Union[npt.NDArray, pd.DataFrame, pd.Index, pd.Series, Any],
        second: Union[npt.NDArray, pd.DataFrame, pd.Index, pd.Series, Any],
        msg: Any = ...,
        **kwargs,
    ) -> None:
        """
        Check if first and second are equal.
        Uses np.testing.assert_array_equal for np.ndarray,
        _assert_frame_equal_lazy for pd.DataFrame,
        assert_series_equal for pd.Series,
        _assert_index_equal_lazy for pd.Index,
        or unittest.TestCase.assertEqual otherwise.

        The comparison is chosen by the first matching type, checked on
        either argument, in the order listed above.

        Args:
            first (Union[np.ndarray, pd.DataFrame, pd.Index, pd.Series]): first element to compare.
            second (Union[np.ndarray, pd.DataFrame, pd.Index, pd.Series]): second element to compare.
            msg (Any, optional): The error message on failure. Only used by
                the unittest.TestCase.assertEqual fallback.
            **kwargs: ``lazy`` (bool, default False) is consumed here and
                passed to the DataFrame/Index comparisons. Any remaining
                keyword arguments are forwarded to
                np.testing.assert_array_equal or
                unittest.TestCase.assertEqual; they are not forwarded to the
                pandas comparisons.
        """
        # pop (not get) so ``lazy`` is never forwarded to numpy/unittest,
        # which do not accept it.
        lazy: bool = kwargs.pop("lazy", False)
        # The Ellipsis default means "no message"; unittest expects None for
        # that, otherwise " : Ellipsis" is appended to failure messages.
        if msg is ...:
            msg = None

        if isinstance(first, np.ndarray) or isinstance(second, np.ndarray):
            np.testing.assert_array_equal(first, second, **kwargs)
        elif isinstance(first, pd.DataFrame) or isinstance(second, pd.DataFrame):
            _assert_frame_equal_lazy(
                first,
                second,
                lazy,
            )
        elif isinstance(first, pd.Series) or isinstance(second, pd.Series):
            pd.testing.assert_series_equal(first, second)
        elif isinstance(first, pd.Index) or isinstance(second, pd.Index):
            _assert_index_equal_lazy(first, second, lazy)
        else:
            super().assertEqual(first, second, msg=msg, **kwargs)

    # Some Prints
    def assertPrints(self, callable, *args, **kwargs) -> None:
        """Assert that ``callable(*args, **kwargs)`` writes something to
        stdout or stderr."""
        with _capture_output() as (out, err):
            callable(*args, **kwargs)
        out, err = out.getvalue(), err.getvalue()
        self.assertTrue((len(out) + len(err)) > 0, "No printed output.")

    def assertNotPrints(self, callable, *args, **kwargs) -> None:
        """Assert that ``callable(*args, **kwargs)`` writes nothing to
        stdout or stderr."""
        with _capture_output() as (out, err):
            callable(*args, **kwargs)
        out, err = out.getvalue(), err.getvalue()
        self.assertTrue(
            (len(out) + len(err)) == 0,
            f"Printed output is longer than 0: {(out, err)}.",
        )

    def assertPrintsRegexp(self, regexp, callable, *args, **kwargs) -> None:
        """Assert that the stdout text or the stderr text produced by
        ``callable(*args, **kwargs)`` matches ``regexp`` (``re.search``,
        each stream searched separately)."""
        with _capture_output() as (out, err):
            callable(*args, **kwargs)
        out, err = out.getvalue(), err.getvalue()
        self.assertTrue(
            any((re.search(regexp, o) is not None) for o in (out, err)),
            f"Printed output {(out, err)} does not match regex {regexp}.",
        )

    def assertNotPrintsRegexp(self, regexp, callable, *args, **kwargs) -> None:
        """Assert that neither the stdout text nor the stderr text produced
        by ``callable(*args, **kwargs)`` matches ``regexp`` (``re.search``)."""
        with _capture_output() as (out, err):
            callable(*args, **kwargs)
        out, err = out.getvalue(), err.getvalue()
        self.assertFalse(
            any((re.search(regexp, o) is not None) for o in (out, err)),
            f"Printed output {(out, err)} matches regex {regexp}.",
        )