Source code for dysh.util.core

"""
Core utility definitions, classes, and functions
"""

import hashlib
import re
import sys
from pathlib import Path
from typing import Union

# import astropy.units as u
import numpy as np

# import pandas as pd
from astropy.time import Time

ALL_CHANNELS = "all channels"


def select_from(key, value, df):
    """
    Select data where key=value.

    Parameters
    ----------
    key : str
        The key value (SDFITS column name)
    value : any
        The value to match.  If `None`, the rows whose `key` entry is null
        (None/NaN) are selected, because an equality comparison against
        `None` never matches any row.
    df : `~pandas.DataFrame`
        The DataFrame to search

    Returns
    -------
    df : `~pandas.DataFrame`
        The subselected DataFrame
    """
    column = df[key]
    if value is None:
        # `column == None` is False everywhere, so match nulls explicitly.
        return df[column.isna()]
    return df[column == value]
def eliminate_flagged_rows(df, flag):
    """
    Drop the rows of an index (selection) for which every channel is flagged.

    Parameters
    ----------
    df : `~pandas.DataFrame`
        The input dataframe from which flagged rows will be removed.
    flag : `~pandas.DataFrame`
        The flag dataframe.  Should be the result of e.g. `~util.Flag.final`

    Returns
    -------
    `~pandas.DataFrame`
        The input data frame with the fully flagged rows removed.  The input
        object itself is returned when nothing needs to be removed.
    """
    # No flags at all: nothing to do.
    if len(flag) == 0:
        return df

    # In the final flagging selection, a row with CHAN=ALL_CHANNELS means
    # that the entire row is flagged.
    whole_row_flags = flag[flag["CHAN"].isin([ALL_CHANNELS])]
    bad_rows = set(whole_row_flags["ROW"])
    if len(bad_rows) == 0:
        return df

    keep_rows = list(set(df["ROW"]) - bad_rows)
    if len(keep_rows) == 0:
        # all rows removed
        return df.iloc[0:0]
    return df[df["ROW"].isin(keep_rows)]
def indices_where_value_changes(colname, df):
    """
    Locate the `~pandas.DataFrame` indices at which the value in the given
    column differs from the value in the previous row.

    Parameters
    ----------
    colname : str
        The column name to query.
    df : `~pandas.DataFrame`
        The DataFrame to search

    Returns
    -------
    indices : ~numpy.ndarray
        The indices of the Dataframe where `colname` changes value.

    Raises
    ------
    KeyError
        If `colname` is not a column of `df`.
    """
    # @todo add option to return changing values along with index
    # e.g., [["A",0],["B",125],["C",246]]
    # See https://stackoverflow.com/questions/48673046/get-index-where-value-changes-in-pandas-dataframe-column
    if colname not in df:
        raise KeyError(f"Column {colname} not in input DataFrame")
    # df.shift() moves every row down by one, so comparing against it tests
    # row N against row N-1 and yields a truth table of "value changed here".
    changed = df.ne(df.shift())
    # Keep only the column of interest.
    changed_col = changed.filter(items=[colname])
    # Turn the boolean column into the list of index labels where it is True.
    index_lists = changed_col.apply(lambda flags: flags.index[flags].tolist())
    # Drop the length-one column axis left over from the single-column frame.
    return np.squeeze(index_lists.values, axis=1)
def gbt_timestamp_to_time(timestamp):
    """Build an :class:`~astropy.time.Time` object from GBT SDFITS timestamp string(s).

    GBT SDFITS timestamps have the form YYYY_MM_DD_HH:MM:SS in UTC.

    Parameters
    ----------
    timestamp : str or list-like
        The GBT format timestamp as described above. If str, a Time object containing
        a single time is returned. If list-like, a Time object containing multiple UTC
        times is returned.

    Returns
    -------
    time : `~astropy.time.Time`
        The time object
    """

    def _to_fits_iso(stamp):
        # YYYY_MM_DD_HH:MM:SS -> YYYY-MM-DDTHH:MM:SS(.SSS): the first two
        # underscores become dashes, the remaining one becomes the "T".
        with_dashes = stamp.replace("_", "-", 2)
        return with_dashes.replace("_", "T")

    if isinstance(timestamp, str):
        iso = _to_fits_iso(timestamp)
    else:
        iso = [_to_fits_iso(stamp) for stamp in timestamp]
    return Time(iso, scale="utc")
def generate_tag(values, hashlen, add_time=True):
    """
    Generate a unique tag based on input values.  A hash object is created
    from the input values using SHA256, and a hex representation is created.
    The first `hashlen` characters of the hex string are returned.

    Parameters
    ----------
    values : array-like
        The values to use in creating the hash object.  Any iterable is
        accepted; it is copied and never modified.
    hashlen : int
        The length of the returned hash string.
    add_time : bool
        Add the time of the call to the values for hash generation.

    Returns
    -------
    tag : str
        The hash string
    """
    # Work on a copy so the caller's sequence is not mutated and so that
    # tuples/arrays (which have no ``append``) are accepted as well.
    items = list(values)
    if add_time:
        items.append(Time.now().value)
    data = "".join(map(str, items))
    unique_id = hashlib.sha256(data.encode()).hexdigest()
    return unique_id[0:hashlen]
def consecutive(data, stepsize=1):
    """Split `data` into groups of neighbouring values.

    A new group is started wherever the difference between two adjacent
    elements of `data` is greater than or equal to `stepsize`; adjacent
    elements that differ by less than `stepsize` stay in the same group.

    Parameters
    ----------
    data : array
        Array with values to split.
    stepsize : int
        Difference between adjacent elements of `data` at (or above) which
        a new group begins.

    Returns
    -------
    groups : list of `~numpy.ndarray`
        The values of `data` separated into groups, as returned by
        `numpy.split`.
    """
    jumps = np.diff(data) >= stepsize
    # diff[i] compares data[i+1] with data[i], so the new group begins at i+1.
    split_points = np.where(jumps)[0] + 1
    return np.split(data, split_points)
def sq_weighted_avg(a, axis=0, weights=None):
    # @todo make a generic moment or use scipy.stats.moment
    r"""Weighted root-mean-square of an array along an axis (2nd moment).

    :math:`v = \sqrt{\frac{\sum_i{w_i~a_i^{2}}}{\sum_i{w_i}}}`

    Parameters
    ----------
    a : `~numpy.ndarray`
        The data to average
    axis : int
        The axis over which to average the data.  Default: 0
    weights : `~numpy.ndarray` or None
        The weights to use in averaging.  The weights array must be the
        length of the axis over which the average is taken.
        Default: `None` will use equal weights.

    Returns
    -------
    average : `~numpy.ndarray`
        The average along the input axis
    """
    # Equal weighting is expressed as an array of ones shaped like the data.
    w = np.ones_like(a) if weights is None else weights
    mean_of_squares = np.average(a * a, axis=axis, weights=w)
    return np.sqrt(mean_of_squares)
def get_project_root() -> Path:
    """
    Returns the project root directory, taken to be four directory levels
    above this source file.
    """
    location = Path(__file__)
    # Walk up: file -> containing directory -> ... four `.parent` steps in all.
    for _ in range(4):
        location = location.parent
    return location
def get_project_testdata() -> Path:
    """
    Returns the project testdata directory, i.e. the ``testdata``
    entry directly under the project root.
    """
    root = get_project_root()
    return root / "testdata"
def get_project_configuration() -> Path:
    """
    Returns the directory where dysh configuration files are kept.

    Returns
    -------
    Path
        The project configuration directory (``conf`` under the project root).
    """
    root = get_project_root()
    return root / "conf"
def get_size(obj, seen=None):
    """Recursively compute the size of an object, following its contents.

    The result is `sys.getsizeof` of `obj` plus the sizes of dictionary
    keys/values, the instance ``__dict__``, or the items of an iterable
    (strings and byte sequences are not iterated).  Each object is counted
    once: an object whose id is already in `seen` contributes 0.

    See https://goshippo.com/blog/measure-real-size-any-python-object/

    Parameters
    ----------
    obj : any
        The object to measure.
    seen : set or None
        Ids of objects already counted.  A fresh set is created when `None`.

    Returns
    -------
    size : int
        The accumulated size in bytes as reported by `sys.getsizeof`.
    """
    total = sys.getsizeof(obj)
    if seen is None:
        seen = set()
    ident = id(obj)
    if ident in seen:
        return 0
    # Important: mark as seen *before* entering recursion to gracefully
    # handle self-referential objects.
    seen.add(ident)
    if isinstance(obj, dict):
        for val in obj.values():
            total += get_size(val, seen)
        for key in obj.keys():
            total += get_size(key, seen)
    elif hasattr(obj, "__dict__"):
        total += get_size(obj.__dict__, seen)
    elif hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes, bytearray)):
        for item in obj:
            total += get_size(item, seen)
    return total
def minimum_string_match(s, valid_strings):
    """
    Return the single valid string that begins with the given
    (possibly abbreviated) input string.

    Example:  minimum_string_match('a',['alpha','beta','gamma'])
    returns:  'alpha'

    Parameters
    ----------
    s : string
        string to use for minimum match
    valid_strings : list of strings
        list of full strings to min match on

    Returns
    -------
    string or None
        The matched string if exactly one entry of `valid_strings` starts
        with `s`.  Otherwise (no match, or an ambiguous match) `None` is
        returned.
    """
    candidates = [full for full in valid_strings if full.startswith(s)]
    if len(candidates) != 1:
        return None
    return candidates[0]
def uniq(seq):
    """Return a list of the items of `seq` with duplicates dropped, keeping
    the first occurrence of each item in its original position.

    from http://stackoverflow.com/questions/480214/how-do-you-remove-duplicates-from-a-list-in-python-whilst-preserving-order
    """
    already = set()
    ordered = []
    for item in seq:
        if item in already:
            continue
        already.add(item)
        ordered.append(item)
    return ordered
def keycase(d, case="upper"):
    """
    Change the case of dictionary keys

    Parameters
    ----------
    d : dict
        The input dictionary
    case : str, one of 'upper', 'lower'
        Case to change keys to.  The default is "upper".

    Returns
    -------
    newDict : dict
        A copy of the dictionary with keys changed according to `case`

    Raises
    ------
    ValueError
        If `case` is neither 'upper' nor 'lower'.
    """
    if case == "upper":
        return {k.upper(): v for k, v in d.items()}
    if case == "lower":
        return {k.lower(): v for k, v in d.items()}
    # Previously an unrecognised `case` fell through to an unbound local
    # variable; fail with an explicit, descriptive error instead.
    raise ValueError(f"Unrecognized case {case!r}. Must be one of 'upper', 'lower'.")
# Example of logging a function call # @log_function_call(log_level="debug")
def powerof2(number):
    """
    Computes the exponent of the power of 2 closest to a given `number`,
    i.e. the base-2 logarithm of `number` rounded to the nearest integer.

    Parameters
    ----------
    number : float
        number to determine the closest power of 2.

    Returns
    -------
    pow2 : int
        the exponent ``n`` such that ``2**n`` is the closest power of 2
        (closeness being measured on the logarithm).
    """
    # log2(number) via the change-of-base formula.
    log2_of_number = np.log10(number) / np.log10(2.0)
    return round(log2_of_number)
# From astropy.io.fits.Card:
# FSC commentary card string which must contain printable ASCII characters.
# Note: \Z matches the end of the string without allowing newlines
_ascii_text_re = re.compile(r"[ -~]*\Z")


def _ensure_ascii_str(text: str, check: bool = False) -> str:
    """Clean a single text string.

    Non-ASCII characters are dropped and newlines are replaced by a blank.

    Parameters
    ----------
    text : str
        The text to clean
    check : bool
        If True, verify that the cleaned text contains only printable ASCII
        characters and raise `ValueError` otherwise.

    Returns
    -------
    str
        The cleaned text
    """
    clean_text = text.encode("ascii", "ignore").decode("ascii")
    clean_text = clean_text.replace("\n", " ")
    if check and _ascii_text_re.match(clean_text) is None:
        raise ValueError(f"Unable to fully clean string:{clean_text!r} of non-ASCII or non-printable characters.")
    return clean_text


def ensure_ascii(text: Union[str, list[str]], check: bool = False) -> Union[str, list[str]]:
    """
    Remove non-ASCII characters from a string or list of strings and replace
    newlines with blanks.  This is to ensure that FITS cards conform to the
    standard.

    Parameters
    ----------
    text : str or list of str
        The text to clean
    check : bool
        Check if the clean value is truly clean according to astropy FITS,
        raise ValueError if not

    Returns
    -------
    str or list[str]
        The cleaned text

    Raises
    ------
    ValueError
        If `check` is True and a cleaned string still contains characters
        other than printable ASCII.
    """
    # `check` must be forwarded to the helper; otherwise the validation
    # requested by the caller is silently skipped.
    if isinstance(text, str):
        return _ensure_ascii_str(text, check)
    return [_ensure_ascii_str(c, check) for c in text]
def convert_array_to_mask(a, length, value=True):
    """
    This method interprets a simple or compound array and returns a numpy mask
    of length `length`. Single arrays/tuples will be treated as element index lists;
    nested arrays will be treated as *inclusive* ranges, for instance:

    ``
    # mask elements 1 and 10
    convert_array_to_mask([1,10], length)
    # mask elements 1 thru 10 inclusive
    convert_array_to_mask([[1,10]], length)
    # mask ranges 1 thru 10 and 47 thru 56 inclusive, and element 75
    convert_array_to_mask([[1,10], [47,56], 75], length)
    # tuples also work, though can be harder for a human to read
    convert_array_to_mask(((1,10), [47,56], 75), length)
    ``

    Parameters
    ----------
    a : number or array-like
        The element index, list of element indices and/or inclusive
        ``[first, last]`` ranges to set, or the `ALL_CHANNELS` string to
        set every element.
    length : int
        The length of the mask to return, e.g. the number of channels in a spectrum.
    value : bool
        The value to fill the mask with. True to mask data, False to unmask.

    Returns
    -------
    mask : ~np.ndarray
        A numpy array where the selected elements are set to `value`
        according to the rules above; all other elements are False.
    """
    # Only a string can equal the ALL_CHANNELS sentinel.  Checking the type
    # first avoids an elementwise comparison (whose truth value is ambiguous)
    # when `a` is a numpy array.
    if isinstance(a, str) and a == ALL_CHANNELS:
        return np.full(length, value)

    mask = np.full(length, False)
    # A bare number is a single element index.
    if np.isscalar(a) and not isinstance(a, str):
        a = [a]
    for v in a:
        if isinstance(v, (tuple, list, np.ndarray)) and len(v) == 2:
            # If there are just two numbers, interpret it as an inclusive range
            mask[v[0] : v[1] + 1] = value
        else:
            mask[v] = value
    return mask
def abbreviate_to(length, value, squeeze=True):
    """
    Abbreviate a value for display in limited space. The abbreviated
    value will have initial characters, ellipsis, and final characters,
    e.g.  '[(a,b),(c,d)...(w,x),(y,z)]'.

    Parameters
    ----------
    length : int
        Maximum string length.
    value : any
        The value to be abbreviated.
    squeeze : bool, optional
        Squeeze blanks. If True, replace ", " (comma space) with "," (comma). The default is True.

    Returns
    -------
    strv : str
        Abbreviated string representation of the input value.  When `length`
        is too small to hold leading characters, an ellipsis and trailing
        characters, the string is truncated (with a trailing ellipsis if
        there is room for one) so that it never exceeds `length`.
    """
    strv = str(value)
    if squeeze:
        strv = strv.replace(", ", ",")
    if len(strv) <= length:
        return strv

    bc = int(length / 2) - 1  # number of leading characters kept
    ec = bc - 1  # number of trailing characters kept
    if ec >= 1:
        return strv[0:bc] + "..." + strv[-ec:]

    # With no room for trailing characters, `strv[-0:]` would return the
    # whole string; fall back to a plain truncation that honours `length`.
    limit = max(int(length), 0)
    if limit <= 3:
        return strv[:limit]
    return strv[: limit - 3] + "..."