"""Argument checking and processing helpers (alogos._utilities.argument_processing)."""

from collections.abc import Callable as _Callable
from numbers import Number as _Number

from .operating_system import NEWLINE as _NEWLINE


def check_arg(arg_name, arg_val, types=None, vals=None, conv=None):
    """Check an argument by inspecting its type and value.

    Parameters
    ----------
    arg_name : str
        Name of the argument; only used to build error messages.
    arg_val : object
        The value to check.
    types : iterable of type, optional
        If given and non-empty, ``arg_val`` must be an instance of at least
        one of these types.
    vals : container, optional
        If given and non-empty, ``arg_val`` must be a member of it.
    conv : mapping, optional
        If given and non-empty, and ``arg_val`` is one of its keys, the
        value is replaced by ``conv[arg_val]`` after all checks passed.

    Returns
    -------
    object
        ``arg_val``, possibly replaced via ``conv``.

    Raises
    ------
    TypeError
        If ``arg_val`` is not an instance of any of ``types``.
    ValueError
        If ``arg_val`` is not contained in ``vals``.

    """
    # Check if the argument type is valid
    if types:
        types = tuple(types)
        if not isinstance(arg_val, types):
            arg_val_type = type(arg_val).__name__
            type_names = ", ".join(str(typ) for typ in types)
            message = (
                'Argument "{name}" got a value with an invalid type.{nl}'
                "Given value: {val}{nl}"
                "Given type: {typ}{nl}"
                "Possible types: {options}".format(
                    name=arg_name,
                    val=repr(arg_val),
                    typ=arg_val_type,
                    options=type_names,
                    nl=_NEWLINE,
                )
            )
            raise TypeError(message) from None

    # Check if the argument value is valid
    if vals:
        if arg_val not in vals:
            val_names = ", ".join(repr(val) for val in vals)
            message = (
                'Argument "{name}" got an invalid value.{nl}'
                "Given value: {val}{nl}"
                "Possible values: {options}".format(
                    name=arg_name, val=repr(arg_val), options=val_names, nl=_NEWLINE
                )
            )
            raise ValueError(message) from None

    # Convert the argument value
    if conv:
        try:
            is_key = arg_val in conv
        except TypeError:
            # Unhashable values (lists, or objects defining __eq__ without
            # __hash__) can never be keys of a dict, so there is nothing to
            # convert; previously this crashed with "unhashable type".
            is_key = False
        if is_key:
            arg_val = conv[arg_val]
    return arg_val
def str_arg(arg_name, arg_val, default=None, vals=None, to_lower=False):
    """Validate a string argument, optionally lower-casing it first.

    Parameters
    ----------
    arg_name : str
        Name of the argument, used in error messages.
    arg_val : str or None
        Value to validate.
    default : str, optional
        Substituted when ``arg_val`` is ``None``; a non-None default is also
        what makes ``None`` an accepted input.
    vals : container, optional
        If non-empty, the permitted values (compared after lower-casing).
    to_lower : bool, optional
        Call ``arg_val.lower()`` before checking, if the value has it.

    Returns
    -------
    str or None
        The validated (and possibly lower-cased or defaulted) value.

    """
    if to_lower:
        try:
            lowered = arg_val.lower()
        except AttributeError:
            # Values without .lower() are left untouched; the type check
            # below decides whether they are acceptable.
            lowered = arg_val
        arg_val = lowered
    accepted = (str, type(None)) if default is not None else (str,)
    return check_arg(
        arg_name, arg_val, types=accepted, vals=vals, conv={None: default}
    )
def int_arg(
    arg_name,
    arg_val,
    default=None,
    vals=None,
    min_incl=None,
    max_incl=None,
    allow_none=False,
):
    """Validate an integer argument and enforce optional inclusive bounds.

    Parameters
    ----------
    arg_name : str
        Name of the argument, used in error messages.
    arg_val : int or None
        Value to validate.
    default : int, optional
        Substituted when ``arg_val`` is ``None``; a non-None default also
        makes ``None`` an accepted input.
    vals : container, optional
        If non-empty, the permitted values.
    min_incl, max_incl : int, optional
        Inclusive lower / upper bound; only checked for a non-None result.
    allow_none : bool, optional
        Accept ``None`` even without a default (the result is then ``None``).

    Returns
    -------
    int or None
        The validated value, or ``default`` if ``arg_val`` was ``None``.

    Raises
    ------
    TypeError
        If the value is not an ``int`` (or ``None`` where permitted).
    ValueError
        If the value is not in ``vals`` or lies outside the bounds.

    """
    # NOTE(review): bool is a subclass of int, so True/False pass the type
    # check here -- confirm that is intended by callers.
    none_ok = allow_none or default is not None
    accepted = (int, type(None)) if none_ok else (int,)
    checked = check_arg(
        arg_name, arg_val, types=accepted, vals=vals, conv={None: default}
    )
    if checked is None:
        return checked

    def _out_of_range(which, limit):
        # Both bound messages share one template; only the first word of the
        # last line ("Lowest"/"Highest") differs.
        return ValueError(
            'Argument "{name}" got an invalid value.{nl}'
            "Given value: {val}{nl}"
            "{which} possible value: {limit}".format(
                name=arg_name, val=arg_val, which=which, limit=limit, nl=_NEWLINE
            )
        )

    if min_incl is not None and checked < min_incl:
        raise _out_of_range("Lowest", min_incl) from None
    if max_incl is not None and checked > max_incl:
        raise _out_of_range("Highest", max_incl) from None
    return checked
def num_arg(
    arg_name,
    arg_val,
    default=None,
    vals=None,
    min_incl=None,
    max_incl=None,
    allow_none=False,
):
    """Check a numerical argument and enforce optional inclusive bounds.

    Parameters
    ----------
    arg_name : str
        Name of the argument, used in error messages.
    arg_val : numbers.Number or None
        Value to validate.
    default : numbers.Number, optional
        Substituted when ``arg_val`` is ``None``; a non-None default also
        makes ``None`` an accepted input.
    vals : container, optional
        If non-empty, the permitted values.
    min_incl, max_incl : numbers.Number, optional
        Inclusive lower / upper bound; only checked for a non-None result.
        A NaN value is rejected whenever a bound is given.
    allow_none : bool, optional
        Accept ``None`` even without a default (the result is then ``None``).
        Mirrors the parameter of the same name in ``int_arg``.

    Returns
    -------
    numbers.Number or None
        The validated value, or ``default`` if ``arg_val`` was ``None``.

    Raises
    ------
    TypeError
        If the value is not a ``numbers.Number`` (or ``None`` where permitted).
    ValueError
        If the value is not in ``vals`` or lies outside the bounds.

    """
    if default is not None or allow_none:
        types = (_Number, type(None))
    else:
        types = (_Number,)
    value = check_arg(arg_name, arg_val, types=types, vals=vals, conv={None: default})
    if value is not None:
        # Written as "not (value >= bound)" instead of "value < bound": NaN
        # compares False against everything, so the naive form let NaN slip
        # through the range check unnoticed.
        if min_incl is not None and not value >= min_incl:
            message = (
                'Argument "{name}" got an invalid value.{nl}'
                "Given value: {val}{nl}"
                "Lowest possible value: {minval}".format(
                    name=arg_name, val=arg_val, minval=min_incl, nl=_NEWLINE
                )
            )
            raise ValueError(message) from None
        if max_incl is not None and not value <= max_incl:
            message = (
                'Argument "{name}" got an invalid value.{nl}'
                "Given value: {val}{nl}"
                "Highest possible value: {maxval}".format(
                    name=arg_name, val=arg_val, maxval=max_incl, nl=_NEWLINE
                )
            )
            raise ValueError(message) from None
    return value
def bool_arg(arg_name, arg_val, default=None, vals=None):
    """Validate a boolean argument.

    ``None`` is accepted only when ``default`` is not ``None``, in which case
    it is replaced by ``default``.

    Returns
    -------
    bool or None
        The validated (possibly defaulted) value.

    """
    permitted = [bool]
    if default is not None:
        permitted.append(type(None))
    return check_arg(
        arg_name, arg_val, types=tuple(permitted), vals=vals, conv={None: default}
    )
def callable_arg(arg_name, arg_val, default=None, vals=None):
    """Validate an argument that must be callable.

    The check uses ``collections.abc.Callable``. ``None`` is accepted only
    when ``default`` is not ``None``, in which case it is replaced by
    ``default``.

    Returns
    -------
    callable or None
        The validated (possibly defaulted) value.

    """
    if default is None:
        permitted = (_Callable,)
    else:
        permitted = (_Callable, type(None))
    return check_arg(
        arg_name, arg_val, types=permitted, vals=vals, conv={None: default}
    )
def logical_xor(var1, var2):
    """Return True when exactly one of the two inputs is truthy."""
    first_truthy = bool(var1)
    second_truthy = bool(var2)
    return first_truthy != second_truthy
def ensure_file_extension(filepath, ending):
    """Return ``filepath`` with the extension ``ending`` appended if missing.

    ``ending`` may be given with or without its leading dot. A path that
    already ends with the (dotted) extension is returned unchanged.
    """
    suffix = ending if ending.startswith(".") else "." + str(ending)
    return filepath if filepath.endswith(suffix) else filepath + suffix
def ensure_no_file_extension(filepath, ending):
    """Return ``filepath`` with a trailing extension ``ending`` removed.

    ``ending`` may be given with or without its leading dot. Only one
    occurrence at the very end is stripped; otherwise the path is returned
    unchanged.
    """
    suffix = ending if ending.startswith(".") else "." + str(ending)
    if not filepath.endswith(suffix):
        return filepath
    # suffix is at least "." long, so this slice always drops it exactly.
    return filepath[: len(filepath) - len(suffix)]