Source code for dysh.util.core

"""
Core utility definitions, classes, and functions
"""

import hashlib
import importlib
import numbers
import sys
from collections.abc import Sequence
from itertools import zip_longest
from pathlib import Path

import numpy as np
import pandas as pd
from astropy.table import Table
from astropy.time import Time
from astropy.units.quantity import Quantity
from IPython.display import HTML, display

ALL_CHANNELS = "all channels"


[docs] def select_from(key, value, df): """ Select data where key=value. Parameters ---------- key : str The key value (SDFITS column name) value : any The value to match df : `~pandas.DataFrame` The DataFrame to search Returns ------- df : `~pandas.DataFrame` The subselected DataFrame """ # nb this fails if value is None return df[(df[key] == value)]
[docs] def eliminate_flagged_rows(df, flag): """ Remove rows from an index (selection) where all channels have been flagged. Parameters ---------- df : `~pandas.DataFrame` The input dataframe from which flagged rows will be removed. flag : `~pandas.DataFrame` The flag dataframe. Should be the result of e.g. `~util.Flag.final` Returns ------- A data frame which is the input data frame with flagged rows removed. """ if len(flag) > 0: # in the final flagging selection any rows that have CHAN=ALL_CHANNELS # indicate that the entire row is flagged ff = flag[flag["CHAN"].isin([ALL_CHANNELS])] flagged_rows = set(ff["ROW"]) if len(flagged_rows) > 0: userows = list(set(df["ROW"]) - flagged_rows) if len(userows) > 0: return df[df["ROW"].isin(userows)] else: return df.iloc[0:0] # all rows removed return df
[docs] def indices_where_value_changes(colname, df): """ Find the `~pandas.DataFrame` indices where the value of the input column name changes. Parameters ---------- colname : str The column name to query. df : `~pandas.DataFrame` The DataFrame to search Returns ------- indices : ~numpy.ndarray The indices of the Dataframe where `colname` changes value. """ # @todo add option to return changing values along with index # e.g., [["A",0],["B",125],["C",246]] # This is some super panda kung-fu. # See https://stackoverflow.com/questions/48673046/get-index-where-value-changes-in-pandas-dataframe-column if colname not in df: raise KeyError(f"Column {colname} not in input DataFrame") # df.shift() shifts the index by one, so we are then comparing df[N] to df[N-1]. This gets us # a truth table of where values change. We filter on colname, then return a list of indices # where the value is true. Finally, we squeeze out the empty dimensions of the np array. ary = df.ne(df.shift()).filter(items=[colname]).apply(lambda x: x.index[x].tolist()).to_numpy() return np.squeeze(ary, axis=1)
[docs] def gbt_timestamp_to_time(timestamp): """Convert the GBT sdfits timestamp string format to an :class:`~astropy.time.Time` object. GBT SDFITS timestamps have the form YYYY_MM_DD_HH:MM:SS in UTC. Parameters ---------- timestamp : str or list-like The GBT format timestamp as described above. If str, a Time object containing a single time is returned. If list-like, a Time object containing multiple UTC times is returned. Returns ------- time : `~astropy.time.Time` The time object """ # convert to ISO FITS format YYYY-MM-DDTHH:MM:SS(.SSS) if isinstance(timestamp, str): t = timestamp.replace("_", "-", 2).replace("_", "T") else: t = [ts.replace("_", "-", 2).replace("_", "T") for ts in timestamp] return Time(t, scale="utc")
[docs] def to_mjd_list(time_val: Time | float) -> np.ndarray: """Convert an astropy Time, list of MJD, or single MJD to a list of MJD Parameters ---------- time_val : `~astropy.time.Time` or float or list of float The time value to convert. Returns ------- mjd : ~np.ndarray The Modified Julian Day values in an array. (or None if `time_val` was None) """ if time_val is None: return None # check for Time first since it is also a Sequence if isinstance(time_val, Time): if time_val.isscalar: return np.array([time_val.mjd]) else: return time_val.mjd if isinstance(time_val, (Sequence, np.ndarray)) and not isinstance(time_val, str): # str is also a Sequence return time_val if isinstance(time_val, numbers.Number): return np.array([time_val]) else: raise ValueError(f"Unrecognized type for time value: {type(time_val)}")
[docs] def to_quantity_list(q: Quantity | Sequence) -> Quantity: # if given quanity or [quanity], return [quanity.value]*quantity.units # handle quantities first if isinstance(q, Quantity): if q.isscalar: return [q.value] * q.unit else: return q # now handle lists of quantities if isinstance(q, Sequence): if len(set([x.unit for x in q])) != 1: raise ValueError("Units must all be the same in input list") return [x.value for x in q] * q[0].unit
[docs] def generate_tag(values, hashlen, add_time=True): """ Generate a unique tag based on input values. A hash object is created from the input values using SHA256, and a hex representation is created. The first `hashlen` characters of the hex string are returned. Parameters ---------- values : array-like The values to use in creating the hash object. hashlen : int, optional The length of the returned hash string. add_time : bool Add the time of the call to the values for hash generation. Returns ------- tag : str The hash string. """ if add_time: values.append(Time.now().value) data = "".join(map(str, values)) hash_object = hashlib.sha256(data.encode()) unique_id = hash_object.hexdigest() return unique_id[0:hashlen]
[docs] def consecutive(data, stepsize=1): """ Returns the indices of elements in `data` separated by less than stepsize separated into groups. Parameters ---------- data : array Array with values to split. stepsize : int Maximum separation between elements of `data` to be considered a single group. Returns ------- groups : `~numpy.ndarray` Array with values of `data` separated into groups. """ return np.split(data, np.where(np.diff(data) >= stepsize)[0] + 1)
[docs] def sq_weighted_avg(a, axis=0, weights=None): # @todo make a generic moment or use scipy.stats.moment r"""Compute the mean square weighted average of an array (2nd moment). :math:`v = \sqrt{\frac{\sum_i{w_i~a_i^{2}}}{\sum_i{w_i}}}` Parameters ---------- a : `~numpy.ndarray` The data to average axis : int The axis over which to average the data. Default: 0 weights : `~numpy.ndarray` or None The weights to use in averaging. The weights array must be the length of the axis over which the average is taken. Default: `None` will use equal weights. Returns ------- average : `~numpy.ndarray` The average along the input axis """ if weights is None: w = np.ones_like(a) else: w = weights v = np.sqrt(np.average(a * a, axis=axis, weights=w)) return v
[docs] def get_project_root() -> Path: """ Returns the project root directory. """ return importlib.resources.files("dysh")
[docs] def get_project_testdata() -> Path: """ Returns the project testdata directory """ return get_project_root().parent.parent / "testdata"
[docs] def get_project_data() -> Path: """ Returns the directory where dysh configuration files are kept. Returns ------- Path The project configuration directory. """ return get_project_root() / "data"
[docs] def get_size(obj, seen=None): """Recursively finds size of objects. See https://goshippo.com/blog/measure-real-size-any-python-object/ """ size = sys.getsizeof(obj) if seen is None: seen = set() obj_id = id(obj) if obj_id in seen: return 0 # Important mark as seen *before* entering recursion to gracefully handle # self-referential objects seen.add(obj_id) if isinstance(obj, dict): size += sum([get_size(v, seen) for v in obj.values()]) size += sum([get_size(k, seen) for k in obj.keys()]) elif hasattr(obj, "__dict__"): size += get_size(obj.__dict__, seen) elif hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes, bytearray)): size += sum([get_size(i, seen) for i in obj]) return size
[docs] def minimum_string_match(s, valid_strings, casefold=False): """ return the valid string from a list, given a minimum string input Example: minimum_string_match('a',['alpha','beta','gamma']) returns: 'alpha' Parameters ---------- s : string string to use for minimum match valid_strings : list of strings list of full strings to minimum match on. casefold: bool If True, do a case insensitive match Returns ------- string matched string, if one is found. An exact match will also count as a match, even if others are present with longer match. Otherwise "None" is returned. """ n = len(valid_strings) if casefold: vsfold = [a.casefold() for a in valid_strings] s = s.casefold() else: vsfold = valid_strings m = [] for i in range(n): if vsfold[i].find(s) == 0: m.append(i) if len(m) >= 1: return valid_strings[m[0]] return None
[docs] def minimum_list_match(strings, valid_strings, casefold=False): """ Return the list of valid strings given a list of minimum string inputs. Parameters ---------- strings : str or list of str The strings to compare for minimum match valid_strings : list of str list of full strings to min match on. casefold: bool If True, do a case insensitive match Returns ------- list List of all minimum matches or None if no matches found """ valid = [] # if user passes in a string instead of a list, it should act like minimum_string_match # Note: strings=list(strings) is not the same as [strings]! if isinstance(strings, str): strings = [strings] for s in strings: p = minimum_string_match(s, valid_strings, casefold) if p is not None: valid.append(p) if len(valid) == 0: return None else: return valid
[docs] def uniq(seq): """Remove duplicates from a list while preserving order. from http://stackoverflow.com/questions/480214/how-do-you-remove-duplicates-from-a-list-in-python-whilst-preserving-order """ seen = set() seen_add = seen.add return [x for x in seq if x not in seen and not seen_add(x)]
[docs] def keycase(d, case="upper"): """ Change the case of dictionary keys Parameters ---------- d : dict The input dictionary case : str, one of 'upper', 'lower' Case to change keys to The default is "upper". Returns ------- newDict : dict A copy of the dictionary with keys changed according to `case` """ if case == "upper": newDict = {k.upper(): v for k, v in d.items()} elif case == "lower": newDict = {k.lower(): v for k, v in d.items()} return newDict
# Example of logging a function call # @log_function_call(log_level="debug")
[docs] def powerof2(number): """ Computes the closest power of 2 for a given `number`. Parameters ---------- number : float number to determine the closest power of 2. Returns ------- pow2 : int the closest power of 2. """ return round(np.log10(number) / np.log10(2.0))
[docs] def convert_array_to_mask(a, length, value=True): """ This method interprets a simple or compound array and returns a numpy mask of length `length`. Single arrays/tuples will be treated as element index lists; nested arrays will be treated as *inclusive* ranges. Parameters ---------- a : number or array-like The channels to mask. See the examples for use. length : int The length of the mask to return, e.g. the number of channels in a spectrum. value : bool The value to fill the mask with. True to mask data, False to unmask. Returns ------- mask : `~numpy.ndarray` A numpy array where the mask is `value`. Examples -------- Mask elements 1 and 10. >>> convert_array_to_mask([1,10]) Mask elements 1 thru 10 inclusive. >>> convert_array_to_mask([[1,10]]) Mask ranges 1 thru 10 and 47 thru 56 inclusive, and element 75. >>> convert_array_to_mask([[1,10], [47,56], 75)]) Tuples also work. To do the same as above. >>> convert_array_to_mask(((1,10), [47,56], 75)) """ if str(a) == ALL_CHANNELS: return np.full(length, value) mask = np.full(length, False) for v in a: if isinstance(v, (tuple, list, np.ndarray)) and len(v) == 2: # If there are just two numbers, interpret is as an inclusive range mask[v[0] : v[1] + 1] = value else: mask[v] = value return mask
[docs] def abbreviate_to(length, value, squeeze=True) -> str: """ Abbreviate a value for display in limited space. The abbreviated value will have initial characters, ellipsis, and final characters, e.g. '[(a,b),(c,d),...,(w,x),(y,z)]'. Parameters ---------- length : int Maximum string length. value : any The value to be abbreviated. squeeze : bool, optional Squeeze blanks. If True, replace ", " (comma space) with "," (comma). The default is True. Returns ------- strv : str Abbreviated string representation of the input value. """ strv = str(value) sep = ", " if squeeze: strv = strv.replace(", ", ",") sep = "," if len(strv) > length: try: bc = strv.rindex(sep, 0, length // 2 - 1) except ValueError: bc = strv.index(sep) try: ec = strv[-length // 2 + 1 :].index(sep) eci = -length // 2 + 1 except ValueError: ec = strv.rindex(sep) eci = None strv = strv[:bc] + sep + "..." + strv[eci:][ec:] return strv
[docs] def merge_ranges(ranges): """ Merge overlapping and adjacent ranges and yield the merged ranges in order. The argument must be an iterable of pairs (start, stop). Taken from: https://codereview.stackexchange.com/a/21333 Parameters ---------- ranges : iterable Pairs of (start, stop) ranges. Yields ------ iterable Merged ranges. Examples -------- >>> list(merge_ranges([(5,7), (3,5), (-1,3)])) [(-1, 7)] >>> list(merge_ranges([(5,6), (3,4), (1,2)])) [(1, 2), (3, 4), (5, 6)] >>> list(merge_ranges([])) [] """ ranges = iter(sorted(ranges)) try: current_start, current_stop = next(ranges) except StopIteration: return for start, stop in ranges: if start > current_stop: # Gap between segments: output current segment and start a new one. yield current_start, current_stop current_start, current_stop = start, stop else: # Segments adjacent or overlapping: merge. current_stop = max(current_stop, stop) yield current_start, current_stop
[docs] def grouper(iterable, n, *, incomplete="fill", fillvalue=None): "Collect data into non-overlapping fixed-length chunks or blocks." # grouper('ABCDEFG', 3, fillvalue='x') → ABC DEF Gxx # grouper('ABCDEFG', 3, incomplete='strict') → ABC DEF ValueError # grouper('ABCDEFG', 3, incomplete='ignore') → ABC DEF iterators = [iter(iterable)] * n match incomplete: case "fill": return zip_longest(*iterators, fillvalue=fillvalue) case "strict": return zip(*iterators, strict=True) case "ignore": return zip(*iterators, strict=False) case _: raise ValueError("Expected fill, strict, or ignore")
[docs] def in_notebook() -> bool: """ Check if the code is being run inside a notebook. This returns False for Spyder and other IPython-based IDEs that are not Jupyter environments, even though they may have IPKernelApp. """ try: from IPython import get_ipython ipython = get_ipython() if "IPKernelApp" not in ipython.config: # pragma: no cover return False # Distinguish between Jupyter and Spyder/other IPython IDEs # In Jupyter: 'ZMQInteractiveShell' # In Spyder: 'SpyderShell' or similar # In IPython terminal: 'TerminalInteractiveShell' shell_class = ipython.__class__.__name__ # Only return True for actual Jupyter environments # ZMQInteractiveShell is used by both Jupyter Notebook and JupyterLab return shell_class == "ZMQInteractiveShell" except ImportError: return False except AttributeError: return False
[docs] def show_dataframe(df, show_index=False, max_rows=None, max_cols=None): """ Function to show a `~pandas.DataFrame` in IPython or Jupyter. Parameters ---------- df : `~pandas.DataFrame` The `~pandas.DataFrame` to be shown. show_index : bool Show the index of the `~pandas.DataFrame`. max_rows : int or None Maximum number of rows to display. max_cols : int or None Maximum number of columns to display. """ kwargs = {"max_rows": max_rows, "max_cols": max_cols, "index": show_index} if in_notebook(): display(HTML(df.to_html(**kwargs))) else: print(df.to_string(**kwargs))
[docs] def calc_vegas_spurs( vsprval: float | np.ndarray, vspdelt: float | np.ndarray, vsprpix: float | np.ndarray, maxchan: float, keep_central=False, ) -> np.ma.masked_array: """ Calculate VEGAS spur channel locations. SPUR_CHANNEL = (J-VSPRVAL)*VSPDELT+VSPRPIX - 1 where 0 <= J < 32. Spur channels are counted from zero. Parameters ---------- vsprval : float or `~numpy.ndarray` VEGAS spur channel offset. vspdelt : float or `~numpy.ndarray` VEGAS spur separation width in channels. vsprpix : float or `~numpy.ndarray` VEGAS spur reference pixel. maxchan : float Maximum channel number (counting from zero), above which calculated spurs are masked. keep_central: bool Whether to keep the central VEGAS spur location in the returned array or not. The GBO SDFITS writer by default replaces the value at the central SPUR with the average of the two adjacent channels, and hence the central channel is not typically flagged. Notes ----- All input arrays must have the same shape. Returns ------- `~numpy.ma.masked_array` The array of channel numbers where spurs occur, with shape (`N_vsp`,31) where `N_vsp` is the length of the VSP arrays. Invalid spur locations will be masked. """ NSPURS = 32 spurs = np.outer(np.arange(NSPURS), vspdelt) - vsprval * vspdelt + vsprpix - 1 if not keep_central: # GBTIDL says: The central channel is defined by NCHAN/2 when counting from zero. # Which actually is ambiguous but we will take to mean e.g. 8192 when NCHAN=16384 # This produces the same return array as GBTIDL dcspurschan.pro central = NSPURS // 2 a = np.sort(spurs[np.arange(len(spurs)) != central].astype(int)) else: a = np.sort(spurs.astype(int)) # Mask any spur locations <0 or > maxchan. # Calling functions rely on the number of rows len(a[1]) to match the length of # the input VSP values, so mask instead of dropping rows. a = np.ma.masked_array(a, mask=np.logical_or(a < 0, a > maxchan)) return a.T
[docs] def get_valid_channel_range(channel: list | np.ndarray) -> list: """ Check that a channel range (e.g., that was given to Selection) defines a contiguous range of channels and return a list of length 2 if valid. The returned list can be used in data calibration. Unlike :meth:`~dysh.util.selection.Selection.select_channel`, if the input `channel` list has two elements, it will be assumed that channel[0] is the first chan and channel[1] is the last channel. However, channel ranges specified identically to :meth:`~dysh.util.selection.Selection.select_channel` can also be used as input. Parameters ---------- channel : list|np.ndarray List of beginning and end channel Raises ------ ValueError If the input channel list cannot be converted to [first,last] Returns ------- list An inclusive list of [first_channel, last_channel] """ try: a = np.array(channel) except ValueError as ve: # it was some sort of compound list like [1,4,(30,40)] raise ValueError(f"Could not parse {channel} into a valid,contiguous channel range.") from ve ash = a.shape if ash == (1, 2): first = a[0][0] last = a[0][1] elif ash == (2,): first = a[0] last = a[1] if first > last: raise ValueError(f"In channel range {channel}, first channel is greater than last channel.") return [int(first), int(last)]
[docs] def isot_to_mjd(isot): """ Convert an ISOT string to MJD. """ EPOCH_MJD = 40587 MSEC_IN_A_DAY = 86400e3 return np.array(isot, dtype="datetime64[ms]").astype("int64") / MSEC_IN_A_DAY + EPOCH_MJD
[docs] def replace_col_astype(t: Table, colname: str, astype, fill_value): if hasattr(t[colname], "mask"): savemask = t[colname].mask.copy() else: savemask = False q = np.ma.masked_array(pd.to_numeric(t[colname]), savemask, fill_value=fill_value, dtype=astype) t[colname].fill_value = fill_value t.replace_column(colname, q) if hasattr(t, "mask"): t[colname].mask = savemask
[docs] def inner_channel_slice(nchan: int, fedge: float = 0.1) -> slice: """ Return a slice cropping `fedge` channels at each end. This is inclusive on the upper end to reproduce GBTIDL results. Parameters ---------- nchan : int Number of channels. fedge : float Fraction of edges to crop on each end. For example, `fedge=0.1` will crop 10% of the channels on each end. Returns ------- channel_slice : slice Slice that will select the inner `1-fedge` channels. """ start = round(nchan * fedge) stop = -(start - 1) return slice(start, stop, 1)