"""
A registry for units that can be added to and modified.
"""
import copy
import json
from functools import lru_cache
from hashlib import md5
from sympy import sympify
from unyt import dimensions as unyt_dims
from unyt._unit_lookup_table import default_unit_symbol_lut, unit_prefixes
from unyt.exceptions import SymbolNotFoundError, UnitParseError
from unyt.unit_systems import _split_prefix, mks_unit_system, unit_system_registry
def _sanitize_unit_system(unit_system, obj):
    """
    Resolve *unit_system* to an entry of ``unit_system_registry``.

    Parameters
    ----------
    unit_system : object, str or None
        What to resolve. ``None`` falls back to
        ``obj.units.registry.unit_system`` and, if that attribute chain is
        missing, to ``mks_unit_system``. An object with a ``name`` attribute
        is looked up by that name; an object with a ``unit_registry``
        attribute is looked up by that registry's ``unit_system_id``; the
        string ``"code"`` is looked up by ``obj``'s registry
        ``unit_system_id``. Anything else is used as the key after ``str()``.
    obj : object
        Object consulted for the ``None`` and ``"code"`` cases.

    Returns
    -------
    The matching value stored in ``unit_system_registry``.
    """
    if unit_system is None:
        try:
            unit_system = obj.units.registry.unit_system
        except AttributeError:
            # obj is None or carries no registry: default to MKS
            unit_system = mks_unit_system

    # A ``name`` attribute wins over every other way of identifying the
    # unit system, and is used as the key without string conversion.
    if hasattr(unit_system, "name"):
        return unit_system_registry[unit_system.name]

    if hasattr(unit_system, "unit_registry"):
        registry_key = unit_system.unit_registry.unit_system_id
    elif unit_system == "code":
        registry_key = obj.units.registry.unit_system_id
    else:
        registry_key = unit_system
    return unit_system_registry[str(registry_key)]
@lru_cache(maxsize=128)
def cached_sympify(u):
    """
    Memoized ``sympify`` that resolves names against ``unyt.dimensions``.

    Successive loads of unit systems produce the same calls to sympify
    in UnitRegistry.from_json, and even a single load usually repeats a
    few dimension strings, so caching the parsed result is a net win.

    Parameters
    ----------
    u : str
        The expression to parse; it must be hashable because it is the
        cache key.

    Notes
    -----
    NOTE(review): SymPy documents that ``sympify`` evaluates string input
    with ``eval``, so this should only be fed text from trusted sources.
    """
    dimension_namespace = vars(unyt_dims)
    return sympify(u, locals=dimension_namespace)
class UnitRegistry:
    """A registry for unit symbols.

    The registry wraps a lookup table ``lut`` that maps a symbol name to a
    5-tuple ``(base_value, dimensions, offset, tex_repr, prefixable)``.
    """

    # Cached value of ``unit_system_id``; reset to None by add/remove/modify.
    _unit_system_id = None

    def __init__(self, add_default_symbols=True, lut=None, unit_system=None):
        """
        Parameters
        ----------
        add_default_symbols : bool
            If True, merge ``default_unit_symbol_lut`` into the lookup table
            (default entries overwrite same-named entries of *lut*).
        lut : dict, optional
            Initial lookup table. A non-empty dict is adopted by reference,
            not copied, so later changes are visible to the caller.
        unit_system : optional
            Unit system specification, resolved by ``_sanitize_unit_system``.
        """
        self._unit_object_cache = {}
        self.lut = lut if lut else {}
        self.unit_system = _sanitize_unit_system(unit_system, None)
        if add_default_symbols:
            self.lut.update(default_unit_symbol_lut)

    def __getitem__(self, key):
        """
        Return the lookup-table entry for ``str(key)``.

        If the symbol is not stored directly, a prefixed form is resolved
        through ``_lookup_unit_symbol`` (which also stores the new entry).

        Raises
        ------
        SymbolNotFoundError
            If the symbol cannot be resolved.
        """
        name = str(key)
        try:
            return self.lut[name]
        except KeyError:
            pass
        try:
            _lookup_unit_symbol(name, self.lut)
        except UnitParseError as err:
            # Chain explicitly so the parse failure stays visible as the cause.
            raise SymbolNotFoundError(
                f"The symbol '{key}' does not exist in this registry."
            ) from err
        return self.lut[name]

    def __contains__(self, item):
        """Return True if ``str(item)`` is stored or resolvable as a
        prefixed symbol (resolving it adds the entry to the table)."""
        name = str(item)
        if name in self.lut:
            return True
        try:
            _lookup_unit_symbol(name, self.lut)
        except UnitParseError:
            return False
        return True

    @property
    def unit_system_id(self):
        """
        This is a unique identifier for the unit registry, computed as the
        MD5 hex digest of the sorted lookup-table entries. It is needed to
        register a dataset's code unit system in the unit system registry.

        The value is cached and recomputed after the next call to ``add``,
        ``remove`` or ``modify``.
        """
        if self._unit_system_id is None:
            digest = md5()
            for symbol, entry in sorted(self.lut.items()):
                digest.update(symbol.encode("utf8"))
                digest.update(repr(entry).encode("utf8"))
            self._unit_system_id = str(digest.hexdigest())
        return self._unit_system_id

    @property
    def prefixable_units(self):
        """List of the symbols whose entry is flagged as prefixable."""
        return [symbol for symbol, entry in self.lut.items() if entry[4]]

    def add(
        self,
        symbol,
        base_value,
        dimensions,
        tex_repr=None,
        offset=None,
        prefixable=False,
    ):
        """
        Add a symbol to this registry.

        Parameters
        ----------
        symbol : str
            The name of the unit
        base_value : float
            The scaling from the units value to the equivalent SI unit
            with the same dimensions
        dimensions : expr
            The dimensions of the unit
        tex_repr : str, optional
            The LaTeX representation of the unit. If not provided a LaTeX
            representation is automatically generated from the name of
            the unit.
        offset : float, optional
            If set, the zero-point offset to apply to the unit to convert
            to SI. This is mostly used for units like Fahrenheit and
            Celsius that are not defined on an absolute scale.
        prefixable : bool
            If True, then SI-prefix versions of the unit will be created
            along with the unit itself.

        Raises
        ------
        UnitParseError
            If *base_value* or *offset* is not a ``float``.
        """
        from unyt.unit_object import _validate_dimensions

        self._unit_system_id = None

        # Validate
        if not isinstance(base_value, float):
            raise UnitParseError(
                f"base_value ({base_value}) must be a float, got a {type(base_value)}."
            )
        if offset is None:
            offset = 0.0
        elif not isinstance(offset, float):
            raise UnitParseError(
                f"offset value ({offset}) must be a float, got a {type(offset)}."
            )
        _validate_dimensions(dimensions)

        if tex_repr is None:
            # make educated guess that will look nice in most cases
            tex_repr = r"\rm{" + symbol.replace("_", r"\ ") + "}"

        # Add to lut
        self.lut[symbol] = (base_value, dimensions, offset, tex_repr, prefixable)
        # Re-adding an existing symbol must not leave a unit object built
        # from the previous definition in the cache (as in remove/modify).
        self._unit_object_cache.pop(symbol, None)

    def remove(self, symbol):
        """
        Remove the entry for the unit matching `symbol`.

        Parameters
        ----------
        symbol : str
            The name of the unit symbol to remove from the registry.

        Raises
        ------
        SymbolNotFoundError
            If `symbol` is not in the lookup table.
        """
        self._unit_system_id = None

        if symbol not in self.lut:
            raise SymbolNotFoundError(
                "Tried to remove the symbol '%s', but it does not exist "
                "in this registry." % symbol
            )

        del self.lut[symbol]
        self._unit_object_cache.pop(symbol, None)

    def modify(self, symbol, base_value):
        """
        Change the base value of a unit symbol. Useful for adjusting code
        units after parsing parameters.

        Parameters
        ----------
        symbol : str
            The name of the symbol to modify
        base_value : float
            The new base_value for the symbol. An object with an
            ``in_base`` method is converted to MKS first, and its
            dimensions replace the stored ones.

        Raises
        ------
        SymbolNotFoundError
            If `symbol` is not in the lookup table.
        """
        self._unit_system_id = None

        if symbol not in self.lut:
            raise SymbolNotFoundError(
                "Tried to modify the symbol '%s', but it does not exist "
                "in this registry." % symbol
            )

        current = self.lut[symbol]
        if hasattr(base_value, "in_base"):
            new_dimensions = base_value.units.dimensions
            base_value = base_value.in_base("mks").value
        else:
            new_dimensions = current[1]

        self.lut[symbol] = (float(base_value), new_dimensions) + current[2:]
        self._unit_object_cache.pop(symbol, None)

    def keys(self):
        """
        Return a view of the unit symbols contained in the lookup table.
        """
        return self.lut.keys()

    def to_json(self):
        """
        Returns a json-serialized version of the unit registry

        The dimensions of every entry are stored as their ``str()`` form.
        """
        sanitized_lut = {
            symbol: (entry[0], str(entry[1]), *entry[2:])
            for symbol, entry in self.lut.items()
        }
        return json.dumps(sanitized_lut)

    @classmethod
    def from_json(cls, json_text):
        """
        Returns a UnitRegistry object from a json-serialized unit registry

        Parameters
        ----------
        json_text : str
            A string containing a json representation of a UnitRegistry

        Notes
        -----
        NOTE(review): dimension strings are parsed through ``sympify``,
        which SymPy documents as using ``eval``; only load trusted JSON.
        """
        data = json.loads(json_text)
        lut = _correct_old_unit_registry(data, sympify=True)
        return cls(lut=lut, add_default_symbols=False)

    def list_same_dimensions(self, unit_object):
        """
        Return a sorted list of base unit names that this registry knows
        about that are of equivalent dimensions to *unit_object*.

        Dimensions are compared by identity (``is``), not equality.
        """
        return sorted(
            symbol
            for symbol, entry in self.lut.items()
            if entry[1] is unit_object.dimensions
        )

    def __deepcopy__(self, memodict=None):
        """
        Return an independent registry with a deep copy of the lookup table.

        The copy is built without re-applying the default symbols, so
        entries that were modified or removed stay that way, and it keeps
        this registry's unit system.
        """
        lut = copy.deepcopy(self.lut, memodict)
        return type(self)(
            add_default_symbols=False,
            lut=lut,
            unit_system=getattr(self, "unit_system", None),
        )
class _NonModifiableUnitRegistry(UnitRegistry):
    """The class of the default unit registry.

    Existing symbols can be neither modified nor removed; both operations
    raise ``TypeError``.
    """

    def modify(self, symbol, base_value):
        """Always raise ``TypeError``; the arguments are ignored."""
        message = "Units from unyt's default registry cannot be modified."
        raise TypeError(message)

    def remove(self, symbol):
        """Always raise ``TypeError``; the argument is ignored."""
        message = "Units from unyt's default registry cannot be removed."
        raise TypeError(message)
#: The default unit registry: a ``_NonModifiableUnitRegistry`` created with
#: the constructor defaults, so its ``modify`` and ``remove`` methods raise
#: ``TypeError``.
default_unit_registry = _NonModifiableUnitRegistry()
def _lookup_unit_symbol(symbol_str, unit_symbol_lut):
"""
Searches for the unit data tuple corresponding to the given symbol.
Parameters
----------
symbol_str : str
The unit symbol to look up.
unit_symbol_lut : dict
Dictionary with symbols as keys and unit data tuples as values.
"""
if symbol_str in unit_symbol_lut:
# lookup successful, return the tuple directly
return unit_symbol_lut[symbol_str]
# could still be a known symbol with a prefix
prefix, symbol_wo_prefix = _split_prefix(symbol_str, unit_symbol_lut)
if prefix:
# lookup successful, it's a symbol with a prefix
unit_data = unit_symbol_lut[symbol_wo_prefix]
prefix_value = unit_prefixes[prefix][0]
# Need to add some special handling for comoving units
# this is fine for now, but it wouldn't work for a general
# unit that has an arbitrary LaTeX representation
if symbol_wo_prefix != "cm" and symbol_wo_prefix.endswith("cm"):
sub_symbol_wo_prefix = symbol_wo_prefix[:-2]
sub_symbol_str = symbol_str[:-2]
else:
sub_symbol_wo_prefix = symbol_wo_prefix
sub_symbol_str = symbol_str
latex_repr = unit_data[3].replace(
"{" + sub_symbol_wo_prefix + "}", "{" + sub_symbol_str + "}"
)
# Leave offset and dimensions the same, but adjust scale factor and
# LaTeX representation
ret = (
unit_data[0] * prefix_value,
unit_data[1],
unit_data[2],
latex_repr,
False,
)
unit_symbol_lut[symbol_str] = ret
return ret
# no dice
raise UnitParseError(
f"Could not find unit symbol '{symbol_str}' in the provided symbols."
)
def _correct_old_unit_registry(data, sympify=False):
    """
    Build a lookup table from deserialized registry *data*.

    Parameters
    ----------
    data : dict
        Maps a symbol to a sequence describing the unit.
    sympify : bool
        If True, element 1 of every entry is parsed with ``cached_sympify``.
        (This parameter shadows the module-level ``sympify`` import inside
        the function.)

    Returns
    -------
    dict
        Maps every symbol to a tuple. Entries with four elements are
        treated as the old format and upgraded: a prefixable flag is
        appended, dimensions are swapped to unyt's dimension singletons and
        the base value is rescaled to MKS. Symbols of
        ``default_unit_symbol_lut`` absent from *data* are added afterwards.
    """
    lut = {}
    for symbol, raw_entry in data.items():
        entry = list(raw_entry)
        if sympify:
            entry[1] = cached_sympify(raw_entry[1])

        if len(entry) == 4:
            # Old unit registry: add SI-prefixability, correct the
            # base_value to MKS and use unyt's dimension singletons.
            is_default = symbol in default_unit_symbol_lut
            entry.append(default_unit_symbol_lut[symbol][4] if is_default else False)

            # Iterate over the factors of the dimensions as they were
            # before any swapping below.
            original_dims = entry[1]
            for factor in original_dims.as_ordered_factors():
                dim, power = factor.as_base_exp()

                # Swap in the singleton that is *equal* to this dimension
                # but not *identical* to it.
                for singleton in unyt_dims.base_dimensions:
                    if not (singleton == dim and singleton is not dim):
                        continue
                    if power != 1:
                        entry[1] /= dim**power
                        entry[1] *= singleton**power
                    else:
                        # power == 1 is special-cased because
                        # id(symbol ** 1) is not necessarily id(symbol)
                        entry[1] /= dim
                        entry[1] *= singleton
                    break

                # correct base value to be in MKS units
                if dim == unyt_dims.mass:
                    entry[0] /= 1000 ** float(power)
                if dim == unyt_dims.length:
                    entry[0] /= 100 ** float(power)

        lut[symbol] = tuple(entry)

    # Fill in every default symbol the data did not provide.
    for symbol, default_entry in default_unit_symbol_lut.items():
        lut.setdefault(symbol, default_entry)
    return lut