# Source code for unyt.mpl_interface

"""
Matplotlib offers support for custom classes, such as unyt_array, allowing customization
of axis information and unit conversion. In the case of unyt, the axis label is set
based on the unyt_array.name and unyt_array.units attributes. It is also possible to
convert the plotted units.

This feature is optional and has to be enabled using the matplotlib_support context
manager.
"""

# -----------------------------------------------------------------------------
# Copyright (c) 2020, yt Development Team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the LICENSE file, distributed with this software.
# -----------------------------------------------------------------------------


try:
    from matplotlib.units import AxisInfo, ConversionInterface, registry
except ImportError:
    pass
else:
    from weakref import WeakKeyDictionary

    from unyt import Unit, unyt_array, unyt_quantity

    __all__ = ["matplotlib_support"]

    class unyt_arrayConverter(ConversionInterface):
        """Matplotlib interface for unyt_array

        A single shared instance is registered in matplotlib's units registry
        for both unyt_array and unyt_quantity by matplotlib_support.
        """

        # The one instance handed out by __new__.
        _instance = None
        # Label style, one of "()", "[]" or "/"; written by matplotlib_support.
        _labelstyle = "()"
        # Axis -> name of the data plotted on it. Weak keys so this cache does
        # not keep Axis objects alive.
        _axisnames = WeakKeyDictionary()

        # ensure that unyt_arrayConverter is a singleton
        def __new__(cls):
            if unyt_arrayConverter._instance is None:
                unyt_arrayConverter._instance = super().__new__(cls)
            return unyt_arrayConverter._instance

        # When matplotlib first encounters a type in its units.registry, it will
        # call default_units to obtain the units. Then it calls axisinfo to
        # customize the axis - in our case, just set the label. Then matplotlib calls
        # convert.

        @staticmethod
        def axisinfo(unit, axis):
            """Set the axis label based on unit

            Parameters
            ----------

            unit : Unit object, string, or tuple
                This parameter comes from unyt_arrayConverter.default_units() or from
                user code such as Axes.plot(), Axis.set_units(), etc. In user code, it
                is possible to convert the plotted units by specifying the new unit as
                a string, such as "ms", or as a tuple, such as ("J", "thermal")
                following the call signature of unyt_array.convert_to_units().
                Only the first element of a tuple is used to build the label.
            axis : Axis object
                The axis being labelled. It is used to look up the data name
                recorded by default_units() and, for the ``'/'`` label style,
                its ``axis_name`` attribute.

            Returns
            -------

            AxisInfo object with the label formatted as in-line math latex. For
            dimensionless units the label is only the data name (possibly empty).
            """
            if isinstance(unit, tuple):
                unit = unit[0]
            unit_obj = unit if isinstance(unit, Unit) else Unit(unit)
            name = unyt_arrayConverter._axisnames.get(axis, "")
            if unit_obj.is_dimensionless:
                label = name
            else:
                name += " "
                unit_str = unit_obj.latex_representation()
                if unyt_arrayConverter._labelstyle == "[]":
                    label = name + "$\\left[" + unit_str + "\\right]$"
                elif unyt_arrayConverter._labelstyle == "/":
                    # With no data name, fall back to a generic symbol q_x / q_y.
                    axsym = "$q_{\\rm" + axis.axis_name + "}$"
                    name = axsym if name == " " else name
                    # Parenthesize compound units so "a / b/c" is unambiguous.
                    if "/" in unit_str:
                        label = name + "$\\;/\\;\\left(" + unit_str + "\\right)$"
                    else:
                        label = name + "$\\;/\\;" + unit_str + "$"
                else:
                    # "()" and any unrecognized style use parentheses.
                    label = name + "$\\left(" + unit_str + "\\right)$"
            return AxisInfo(label=label.strip())

        @staticmethod
        def default_units(x, axis):
            """Return the Unit object of the unyt_array x

            Parameters
            ----------

            x : unyt_array, or a tuple or list of unyt_array/unyt_quantity
                For a tuple or list, the units and name are taken from the
                first element.
            axis : Axis object

            Returns
            -------

            Unit object
            """
            # In the case where the first matplotlib command is setting limits,
            # x may be a tuple of length two (with the same units). Lists of
            # quantities are handled the same way, since convert() also
            # accepts sequences.
            first = x[0] if isinstance(x, (tuple, list)) else x
            name = getattr(first, "name", "")
            units = first.units

            # maintain a mapping between Axis and name since Axis does not point to
            # its underlying data and we want to propagate the name to the axis
            # label in the subsequent call to axisinfo
            unyt_arrayConverter._axisnames[axis] = name if name is not None else ""
            return units

        @staticmethod
        def convert(value, unit, axis):
            """Convert the units of value to unit

            Parameters
            ----------

            value : unyt_array, unyt_quantity, or sequence thereof
            unit : Unit, string or tuple
                This parameter comes from unyt_arrayConverter.default_units() or from
                user code such as Axes.plot(), Axis.set_units(), etc. In user code, it
                is possible to convert the plotted units by specifying the new unit as
                a string, such as "ms", or as a tuple, such as ("J", "thermal")
                following the call signature of unyt_array.convert_to_units().
            axis : Axis object
                Unused; part of matplotlib's ConversionInterface signature.

            Returns
            -------

            unyt_array, or a container of the same type as value holding the
            converted elements

            Raises
            ------

            UnitConversionError if unit does not have the same dimensions as value.
            A non-iterable value that is not a unyt_array raises TypeError, and
            elements without a ``to`` method raise AttributeError.
            """
            # Normalize a single unit into the positional-argument tuple that
            # unyt_array.to() expects.
            if isinstance(unit, (str, Unit)):
                unit = (unit,)
            if isinstance(value, (unyt_array, unyt_quantity)):
                return value.to(*unit)
            # Otherwise value is a container of quantities (e.g. the tuple
            # matplotlib passes when setting limits): convert element-wise and
            # rebuild the same container type.
            return type(value)([obj.to(*unit) for obj in value])

class matplotlib_support:
    """Context manager for enabling the feature

    When used in a with statement, the feature is enabled during the context and
    then disabled after it exits. The feature can also be toggled manually with
    :meth:`enable` and :meth:`disable`, or enabled by calling the instance.

    Parameters
    ----------
    label_style : str
        One of the following set, ``{'()', '[]', '/'}``. These choices
        correspond to the following unit labels:

        * ``'()'`` -> ``'(unit)'``
        * ``'[]'`` -> ``'[unit]'``
        * ``'/'`` -> ``'q_x / unit'``

    Notes
    -----
    The label style is stored on the shared ``unyt_arrayConverter`` class, so
    the most recently set style applies to every plot, whichever instance set
    it. This class relies on ``registry``, ``unyt_array``, ``unyt_quantity`` and
    ``unyt_arrayConverter`` being defined by the optional matplotlib import
    above.
    """

    def __init__(self, label_style="()"):
        # The property setter also updates the converter's class-level style.
        self.label_style = label_style
        self._enabled = False

    def __call__(self):
        """Enable the feature (same as :meth:`enable`)."""
        self.__enter__()

    @property
    def label_style(self):
        """str: One of the following set, ``{'()', '[]', '/'}``.

        These choices correspond to the following unit labels:

        * ``'()'`` -> ``'(unit)'``
        * ``'[]'`` -> ``'[unit]'``
        * ``'/'`` -> ``'q_x / unit'``
        """
        return self._labelstyle

    @label_style.setter
    def label_style(self, label_style="()"):
        self._labelstyle = label_style
        unyt_arrayConverter._labelstyle = label_style

    def __enter__(self):
        """Register the unyt converter with matplotlib and return self."""
        registry[unyt_array] = unyt_arrayConverter()
        registry[unyt_quantity] = unyt_arrayConverter()
        self._enabled = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Unregister the unyt converter; exceptions are not suppressed."""
        # Tolerate entries that are already gone (e.g. disable() was called
        # inside the with block) so leaving the block does not raise KeyError.
        registry.pop(unyt_array, None)
        registry.pop(unyt_quantity, None)
        self._enabled = False

    def enable(self):
        """Enables the feature"""
        self.__enter__()

    def disable(self):
        """Disables the feature (no-op if it is not currently enabled)"""
        if self._enabled:
            self.__exit__(None, None, None)