Source code for dysh.util.core

"""
Core utility definitions, classes, and functions
"""

import hashlib
import importlib
import numbers
import sys
from collections.abc import Sequence
from itertools import zip_longest
from pathlib import Path
from typing import Union

import numpy as np
from astropy.time import Time
from astropy.units.quantity import Quantity
from IPython.display import HTML, display

# Sentinel string used in place of an explicit channel list to mean "every channel".
# In this module it is matched against the "CHAN" column of a flag DataFrame
# (`eliminate_flagged_rows`) and against the `a` argument of `convert_array_to_mask`.
ALL_CHANNELS = "all channels"


def select_from(key, value, df):
    """
    Return the rows of a DataFrame whose `key` column equals `value`.

    Parameters
    ----------
    key : str
        The column to test (SDFITS column name).
    value : any
        The value the column must be equal to.
    df : `~pandas.DataFrame`
        The DataFrame to search.

    Returns
    -------
    df : `~pandas.DataFrame`
        The rows of `df` for which ``df[key] == value``.
    """
    # nb: an equality test like this does not work for value=None
    matches = df[key] == value
    return df[matches]
def eliminate_flagged_rows(df, flag):
    """
    Drop from a selection every row whose channels have all been flagged.

    Parameters
    ----------
    df : `~pandas.DataFrame`
        The input dataframe; must have a "ROW" column.
    flag : `~pandas.DataFrame`
        The flag dataframe, with "ROW" and "CHAN" columns. Should be the
        result of e.g. `~util.Flag.final`.

    Returns
    -------
    `~pandas.DataFrame`
        `df` itself when nothing is fully flagged; otherwise `df` restricted to
        the rows that are not fully flagged (possibly an empty frame).
    """
    if len(flag) == 0:
        return df

    # In the final flagging selection, CHAN == ALL_CHANNELS marks a row
    # as flagged in its entirety.
    whole_row_flags = flag[flag["CHAN"].isin([ALL_CHANNELS])]
    bad_rows = set(whole_row_flags["ROW"])
    if len(bad_rows) == 0:
        return df

    good_rows = list(set(df["ROW"]) - bad_rows)
    if len(good_rows) == 0:
        # every row was removed: return an empty frame with the same columns
        return df.iloc[0:0]
    return df[df["ROW"].isin(good_rows)]
def indices_where_value_changes(colname, df):
    """
    Find the `~pandas.DataFrame` indices where the value of the input column name changes.

    A row counts as a change when its value differs from the value in the row
    before it. The first row has no predecessor, so it is always reported.

    Parameters
    ----------
    colname : str
        The column name to query.
    df : `~pandas.DataFrame`
        The DataFrame to search

    Returns
    -------
    indices : ~numpy.ndarray
        The index labels of the DataFrame where `colname` changes value.

    Raises
    ------
    KeyError
        If `colname` is not a column of `df`.
    """
    # @todo add option to return changing values along with index
    # e.g., [["A",0],["B",125],["C",246]]
    # See https://stackoverflow.com/questions/48673046/get-index-where-value-changes-in-pandas-dataframe-column
    if colname not in df:
        raise KeyError(f"Column {colname} not in input DataFrame")
    # Only the requested column is shifted and compared; there is no need to
    # shift and compare the whole frame just to discard the other columns.
    column = df[colname]
    # shift() moves values down one row, so this compares row N against row N-1.
    changed = column.ne(column.shift())
    return column.index[changed.to_numpy()].to_numpy()
def gbt_timestamp_to_time(timestamp):
    """Build an :class:`~astropy.time.Time` from GBT SDFITS timestamp string(s).

    GBT SDFITS timestamps have the form YYYY_MM_DD_HH:MM:SS in UTC.

    Parameters
    ----------
    timestamp : str or list-like
        The GBT format timestamp as described above. A single str gives a Time
        holding one time; a list-like of str gives a Time holding multiple UTC times.

    Returns
    -------
    time : `~astropy.time.Time`
        The time object
    """

    def _to_isot(stamp):
        # The first two underscores separate the date fields; the one that is
        # left separates date from time: YYYY-MM-DDTHH:MM:SS(.SSS)
        return stamp.replace("_", "-", 2).replace("_", "T")

    if isinstance(timestamp, str):
        converted = _to_isot(timestamp)
    else:
        converted = [_to_isot(stamp) for stamp in timestamp]
    return Time(converted, scale="utc")
def to_mjd_list(time_val: Union[Time, float]) -> np.ndarray:
    """Convert an astropy Time, list of MJD, or single MJD to a list of MJD

    Parameters
    ----------
    time_val : `~astropy.time.Time` or float or list of float
        The time value to convert.

    Returns
    -------
    mjd : ~np.ndarray
        The Modified Julian Day values. A scalar Time or a single number is wrapped in a
        one-element array; a non-scalar Time yields its ``mjd`` attribute; a list-like
        input (other than str) is returned unchanged. None is returned if `time_val` was None.

    Raises
    ------
    ValueError
        If `time_val` is none of the types above (e.g. a str).
    """
    if time_val is None:
        return None

    # Time is tested first since it is also a Sequence
    if isinstance(time_val, Time):
        return np.array([time_val.mjd]) if time_val.isscalar else time_val.mjd

    # str is also a Sequence, so it has to be excluded explicitly
    list_like = isinstance(time_val, (Sequence, np.ndarray))
    if list_like and not isinstance(time_val, str):
        return time_val

    if isinstance(time_val, numbers.Number):
        return np.array([time_val])

    raise ValueError(f"Unrecognized type for time value: {type(time_val)}")
def to_quantity_list(q: Union[Quantity, Sequence]) -> Quantity:
    """
    Turn a Quantity, or a sequence of Quantities, into one list-valued Quantity.

    Given ``quantity`` or ``[quantity, ...]`` this returns ``[value, ...] * unit``.

    Parameters
    ----------
    q : `~astropy.units.quantity.Quantity` or sequence of Quantity
        The input. A non-scalar Quantity is returned as is.

    Returns
    -------
    `~astropy.units.quantity.Quantity`
        The values with their common unit. None is returned when `q` is neither
        a Quantity nor a Sequence.

    Raises
    ------
    ValueError
        If the elements of a sequence input do not all share exactly one unit.
    """
    # Quantities are handled first
    if isinstance(q, Quantity):
        return [q.value] * q.unit if q.isscalar else q

    # then sequences of quantities
    if isinstance(q, Sequence):
        distinct_units = set([item.unit for item in q])
        if len(distinct_units) != 1:
            raise ValueError("Units must all be the same in input list")
        values = [item.value for item in q]
        return values * q[0].unit

    return None
def generate_tag(values, hashlen, add_time=True):
    """
    Generate a unique tag based on input values.

    The string forms of the input values are concatenated, hashed with SHA256,
    and the first `hashlen` characters of the hex digest are returned.

    Parameters
    ----------
    values : array-like
        The values to use in creating the hash object. Any iterable is accepted;
        it is never modified.
    hashlen : int
        The length of the returned hash string.
    add_time : bool
        Add the time of the call to the values for hash generation.

    Returns
    -------
    tag : str
        The hash string
    """
    # Work on a copy: appending the timestamp to `values` itself would change
    # the caller's list and would fail for inputs without append() (tuple, ndarray).
    items = list(values)
    if add_time:
        items.append(Time.now().value)
    data = "".join(map(str, items))
    return hashlib.sha256(data.encode()).hexdigest()[0:hashlen]
def consecutive(data, stepsize=1):
    """Split `data` into groups of neighbouring elements separated by less than `stepsize`.

    A new group is started at every element whose difference from the
    preceding element is greater than or equal to `stepsize`.

    Parameters
    ----------
    data : array
        Array with values to split.
    stepsize : int
        Separation between adjacent elements of `data` at or above which the
        elements are put into different groups.

    Returns
    -------
    groups : list of `~numpy.ndarray`
        The values of `data` separated into groups.
    """
    steps = np.diff(data)
    # +1 turns "position of the gap" into "position of the first element after the gap"
    break_points = np.where(steps >= stepsize)[0] + 1
    return np.split(data, break_points)
def sq_weighted_avg(a, axis=0, weights=None):
    # @todo make a generic moment or use scipy.stats.moment
    r"""Compute the mean square weighted average of an array (2nd moment).

    :math:`v = \sqrt{\frac{\sum_i{w_i~a_i^{2}}}{\sum_i{w_i}}}`

    Parameters
    ----------
    a : `~numpy.ndarray`
        The data to average
    axis : int
        The axis over which to average the data.  Default: 0
    weights : `~numpy.ndarray` or None
        The weights to use in averaging.  The weights array must be the
        length of the axis over which the average is taken.  Default:
        `None` will use equal weights.

    Returns
    -------
    average : `~numpy.ndarray`
        The average along the input axis
    """
    # equal weighting is expressed as an explicit array of ones shaped like `a`
    w = np.ones_like(a) if weights is None else weights
    mean_of_squares = np.average(a * a, axis=axis, weights=w)
    return np.sqrt(mean_of_squares)
def get_project_root() -> Path:
    """
    Returns the project root directory.

    Returns
    -------
    Path
        The location of the installed ``dysh`` package, as reported by
        `importlib.resources.files`.
    """
    # `import importlib` on its own does not guarantee that the `resources`
    # submodule has been loaded, so `importlib.resources` can raise
    # AttributeError unless something else imported it first. Import it explicitly.
    from importlib.resources import files

    return files("dysh")
def get_project_testdata() -> Path:
    """
    Returns the project testdata directory.

    Returns
    -------
    Path
        The ``testdata`` directory located two levels above the project root.
    """
    two_levels_up = get_project_root().parent.parent
    return two_levels_up / "testdata"
def get_project_data() -> Path:
    """
    Returns the directory where dysh configuration files are kept.

    Returns
    -------
    Path
        The project configuration directory, i.e. ``data`` under the project root.
    """
    root = get_project_root()
    return root / "data"
def get_size(obj, seen=None):
    """Recursively add up `sys.getsizeof` for an object and what it contains.

    See https://goshippo.com/blog/measure-real-size-any-python-object/

    Parameters
    ----------
    obj : any
        The object to measure. Dict keys and values, the instance ``__dict__``, and the
        items of iterables (other than str, bytes, bytearray) are followed.
    seen : set or None
        ids of objects already counted; used by the recursion. Callers normally omit it.

    Returns
    -------
    size : int
        Total size in bytes. An object that was already counted contributes 0.
    """
    size = sys.getsizeof(obj)
    if seen is None:
        seen = set()

    marker = id(obj)
    if marker in seen:
        return 0
    # Important: mark as seen *before* recursing so that self-referential
    # objects are handled gracefully.
    seen.add(marker)

    if isinstance(obj, dict):
        for val in obj.values():
            size += get_size(val, seen)
        for key in obj.keys():
            size += get_size(key, seen)
    elif hasattr(obj, "__dict__"):
        size += get_size(obj.__dict__, seen)
    elif hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes, bytearray)):
        for item in obj:
            size += get_size(item, seen)
    return size
def minimum_string_match(s, valid_strings, casefold=False):
    """
    Return the valid string from a list, given a minimum string input.

    Example:  minimum_string_match('a',['alpha','beta','gamma'])
    returns:  'alpha'

    Parameters
    ----------
    s : string
        string to use for minimum match
    valid_strings : list of strings
        list of full strings to minimum match on.
    casefold : bool
        If True, do a case insensitive match

    Returns
    -------
    string
        The matched entry of `valid_strings`. An entry that matches `s` exactly
        is preferred, even if other entries also begin with `s`. Otherwise the
        first entry that begins with `s` is returned, or None if there is none.
    """
    if casefold:
        candidates = [v.casefold() for v in valid_strings]
        s = s.casefold()
    else:
        candidates = valid_strings

    first_prefix_match = None
    for original, candidate in zip(valid_strings, candidates):
        if candidate == s:
            # an exact match wins over any earlier, longer prefix match
            return original
        if first_prefix_match is None and candidate.startswith(s):
            first_prefix_match = original
    return first_prefix_match
def minimum_list_match(strings, valid_strings, casefold=False):
    """
    Minimum-match each input string against a list of valid strings.

    Parameters
    ----------
    strings : str or list of str
        The strings to compare for minimum match
    valid_strings : list of str
        list of full strings to min match on.
    casefold : bool
        If True, do a case insensitive match

    Returns
    -------
    list
        The matches found, in input order (inputs with no match are skipped),
        or None if no matches found
    """
    # A bare string must behave like a one-element list.
    # Note: list(strings) would split it into characters, which is not the same as [strings]!
    if isinstance(strings, str):
        strings = [strings]

    matches = [
        hit
        for item in strings
        if (hit := minimum_string_match(item, valid_strings, casefold)) is not None
    ]
    return matches or None
def uniq(seq):
    """Return the items of `seq` as a list with duplicates dropped, keeping first-seen order.

    from http://stackoverflow.com/questions/480214/how-do-you-remove-duplicates-from-a-list-in-python-whilst-preserving-order
    """
    already_seen = set()
    result = []
    for item in seq:
        if item in already_seen:
            continue
        already_seen.add(item)
        result.append(item)
    return result
def keycase(d, case="upper"):
    """
    Change the case of dictionary keys

    Parameters
    ----------
    d : dict
        The input dictionary; its keys must be strings. It is not modified.
    case : str, one of 'upper', 'lower'
        Case to change keys to. The default is "upper".

    Returns
    -------
    newDict : dict
        A copy of the dictionary with keys changed according to `case`

    Raises
    ------
    ValueError
        If `case` is neither 'upper' nor 'lower'.
    """
    if case == "upper":
        return {k.upper(): v for k, v in d.items()}
    if case == "lower":
        return {k.lower(): v for k, v in d.items()}
    # Without this, an unknown `case` used to surface as an UnboundLocalError
    # on the return statement.
    raise ValueError(f"case must be 'upper' or 'lower', got {case!r}")
# Example of logging a function call # @log_function_call(log_level="debug")
def powerof2(number):
    """
    Computes the closest power of 2 for a given `number`.

    The value returned is the exponent: the base-2 logarithm of `number`
    rounded to the nearest integer.

    Parameters
    ----------
    number : float
        number to determine the closest power of 2.

    Returns
    -------
    pow2 : int
        the exponent of the closest power of 2.
    """
    # log2 obtained through a change of base from log10
    exponent = np.log10(number) / np.log10(2.0)
    return round(exponent)
def convert_array_to_mask(a, length, value=True):
    """
    This method interprets a simple or compound array and returns a numpy mask of length `length`.
    Single arrays/tuples will be treated as element index lists; nested arrays will be treated
    as *inclusive* ranges, for instance:

    ``
    # mask elements 1 and 10
    convert_array_to_mask([1,10], length)
    # mask elements 1 thru 10 inclusive
    convert_array_to_mask([[1,10]], length)
    # mask ranges 1 thru 10 and 47 thru 56 inclusive, and element 75
    convert_array_to_mask([[1,10], [47,56], 75], length)
    # tuples also work, though can be harder for a human to read
    convert_array_to_mask(((1,10), [47,56], 75), length)
    ``

    Parameters
    ----------
    a : number or array-like
        The element(s) to set: a single integer index, an array-like of indices and/or
        two-element inclusive ranges, or the string `ALL_CHANNELS` to select every element.
    length : int
        The length of the mask to return, e.g. the number of channels in a spectrum.
    value : bool
        The value to fill the mask with. True to mask data, False to unmask.

    Returns
    -------
    mask : ~np.ndarray
        A numpy array of length `length` that is `value` at the selected elements
        and False everywhere else.
    """
    # Only compare against the sentinel when `a` is a string. Comparing an ndarray
    # with a str is done elementwise, and using that result in `if` raises
    # "truth value of an array is ambiguous".
    if isinstance(a, str) and a == ALL_CHANNELS:
        return np.full(length, value)
    # A lone index is documented as valid input but is not iterable; wrap it.
    if isinstance(a, numbers.Integral):
        a = [a]
    mask = np.full(length, False)
    for v in a:
        if isinstance(v, (tuple, list, np.ndarray)) and len(v) == 2:
            # If there are just two numbers, interpret it as an inclusive range
            mask[v[0] : v[1] + 1] = value
        else:
            mask[v] = value
    return mask
def abbreviate_to(length, value, squeeze=True):
    """
    Abbreviate a value for display in limited space. The abbreviated
    value will have initial characters, ellipsis, and final characters, e.g.
    '[(a,b),(c,d)...(w,x),(y,z)]'.

    Parameters
    ----------
    length : int
        Maximum string length.
    value : any
        The value to be abbreviated.
    squeeze : bool, optional
        Squeeze blanks. If True, replace ", " (comma space) with "," (comma). The default is True.

    Returns
    -------
    strv : str
        Abbreviated string representation of the input value. It is never longer
        than `length`; when `length` is too small to fit initial characters, ellipsis
        and final characters, the string is simply cut off at `length` characters.
    """
    strv = str(value)
    if squeeze:
        strv = strv.replace(", ", ",")
    if len(strv) <= length:
        return strv

    head = int(length / 2) - 1
    tail = head - 1
    if tail < 1:
        # Not enough room for head + "..." + tail. A tail of 0 must not reach the
        # slice below, because strv[-0:] is the *whole* string.
        return strv[: max(int(length), 0)]
    return strv[:head] + "..." + strv[-tail:]
def merge_ranges(ranges):
    """
    Merge overlapping and adjacent ranges and yield the merged ranges
    in order. The argument must be an iterable of pairs (start, stop).
    Taken from: https://codereview.stackexchange.com/a/21333

    Parameters
    ----------
    ranges : iterable
        Pairs of (start, stop) ranges.

    Yields
    ------
    iterable
        Merged ranges.

    Examples
    --------
    >>> list(merge_ranges([(5,7), (3,5), (-1,3)]))
    [(-1, 7)]
    >>> list(merge_ranges([(5,6), (3,4), (1,2)]))
    [(1, 2), (3, 4), (5, 6)]
    >>> list(merge_ranges([]))
    []
    """
    ordered = sorted(ranges)
    if not ordered:
        return

    lo, hi = ordered[0]
    for start, stop in ordered[1:]:
        if start <= hi:
            # Adjacent or overlapping: extend the segment being built.
            hi = max(hi, stop)
            continue
        # Gap: emit the finished segment and begin a new one.
        yield lo, hi
        lo, hi = start, stop
    yield lo, hi
[docs] def grouper(iterable, n, *, incomplete="fill", fillvalue=None): "Collect data into non-overlapping fixed-length chunks or blocks." # grouper('ABCDEFG', 3, fillvalue='x') → ABC DEF Gxx # grouper('ABCDEFG', 3, incomplete='strict') → ABC DEF ValueError # grouper('ABCDEFG', 3, incomplete='ignore') → ABC DEF iterators = [iter(iterable)] * n match incomplete: case "fill": return zip_longest(*iterators, fillvalue=fillvalue) case "strict": return zip(*iterators, strict=True) case "ignore": return zip(*iterators, strict=False) case _: raise ValueError("Expected fill, strict, or ignore")
def in_notebook() -> bool:
    """
    Check if the code is being run inside a notebook.

    Returns
    -------
    bool
        True when the configuration of the active IPython shell contains
        "IPKernelApp". False otherwise, including when IPython cannot be
        imported or the lookup raises AttributeError.
    """
    try:
        from IPython import get_ipython

        has_kernel_app = "IPKernelApp" in get_ipython().config  # pragma: no cover
    except (ImportError, AttributeError):
        return False
    return has_kernel_app
def show_dataframe(df, show_index=False, max_rows=None, max_cols=None):
    """
    Show a `~pandas.DataFrame` in IPython or Jupyter.

    In a notebook the frame is displayed as HTML; otherwise its string form is printed.

    Parameters
    ----------
    df : `~pandas.DataFrame`
        The `~pandas.DataFrame` to be shown.
    show_index : bool
        Show the index of the `~pandas.DataFrame`.
    max_rows : int or None
        Maximum number of rows to display.
    max_cols : int or None
        Maximum number of columns to display.
    """
    # the same options are accepted by both to_html and to_string
    render_opts = dict(max_rows=max_rows, max_cols=max_cols, index=show_index)
    if not in_notebook():
        print(df.to_string(**render_opts))
        return
    display(HTML(df.to_html(**render_opts)))