Array API Prototype Demo

A guide to LIGO-Virgo detector noise

This notebook is modified and built on top of the work by GWOSC (GitHub). We would like to credit the authors of the paper and Gravitational Wave Open Science Center Tutorials for their work and open sourcing the code for generating the analysis. This notebook is intended as an educational demonstration to reproduce some of the figures in the paper “A guide to LIGO-Virgo detector noise and extraction of transient gravitational-wave signals”.

PyTorch Tensors with SciPy?

We aim to showcase an interoperable array protocol with this “Guide to GW detections and noise” real world example. The goal with this tutorial examples is to show the usage of PyTorch with an already existing codebase written SciPy or any other array/tensor consuming library for that matter built around NumPy. We would like to achieve this without making a lot of amends.

drawing drawing drawing drawing drawing

Technically, the only changes involved for a NumPy end user: np. -> xp.. Where xp represents any array/tensor implementation compliant with the Array API.

We explain more about the motivation for making array libraries interoperable in this blog.

Download data files

Download each of the following files into the directory where you would like to run the notebook. Most of these data files are described at:

https://www.gw-openscience.org/events/GW150914/

# -- Helper python module to open LIGO data files
! wget -q -N https://www.gw-openscience.org/static/sample_code/readligo.py 
    
# -- Download 4096 second data files around GW150914
! wget -q -N https://www.gw-openscience.org/GW150914data/L-L1_LOSC_4_V2-1126257414-4096.hdf5
! wget -q -N https://www.gw-openscience.org/GW150914data/H-H1_LOSC_4_V2-1126257414-4096.hdf5
    
# -- Download 32 second data files around GW150914
! wget -q -N https://www.gw-openscience.org/GW150914data/H-H1_LOSC_4_V2-1126259446-32.hdf5
! wget -q -N https://www.gw-openscience.org/GW150914data/L-L1_LOSC_4_V2-1126259446-32.hdf5
    
# -- Download waveform template
! wget -q -N https://www.gw-openscience.org/GW150914data/P150914/fig2-unfiltered-waveform-H.txt
    
# -- Download data from Figure 1 of GW150914 detection paper
! wget -q -N https://www.gw-openscience.org/GW150914data/P150914/fig1-observed-H.txt
! wget -q -N https://www.gw-openscience.org/GW150914data/P150914/fig1-residual-H.txt
! wget -q -N https://www.gw-openscience.org/GW150914data/P150914/fig1-observed-L.txt
! wget -q -N https://www.gw-openscience.org/GW150914data/P150914/fig1-residual-L.txt
# Standard python numerical analysis imports:
import numpy as np
import torch

# https://github.com/numpy/numpy/pull/18585/
# import numpy.array_api as xp 

import math
import scipy
from scipy import signal
from scipy.interpolate import interp1d  # Not compatible yet
from scipy.signal import butter, filtfilt  # Not compatible yet

# !pip install h5py
import h5py
import json

# plotting items
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab

# -- Set some parameters that control how plots are displayed
plt.rcParams['figure.figsize'] = [5, 3]
plt.rcParams['lines.linewidth'] = .75

# LIGO-specific readligo.py 
# Download from https://www.gw-openscience.org/static/sample_code/readligo.py
# Note: This has been edited for the PyTorch demo with array-api available at,
#       https://github.com/AnirudhDagar/array-api-demo/blob/master/readligo.py
import readligo as rl

Array-API Demo Background

If you are not a maintainer/core dev at some array consumer library like SciPy, EinOps etc., feel free to skip this section. Through this section, we’d like to explain the very simple machinery and mechanics that goes behind making something like this possible.

__array_namespace__

Let’s understand about the __array_namespace__ protocol first. More details in NEP 47. This is another magic method in python which can be attached to any tensor/array object to represent compliance for some specific Array API version. In the example below we specifically create a dummy __array_namespace__ method and monkey-patch it to PyTorch’s Tensor class. Something like this is only for demonstration purposes and will not be needed in a future released version 1.11 for PyTorch.

Note: This is needed since Array API in PyTorch is under development. See the progress: module python array api.

from typing import Optional

def __array_namespace__(Array, /, *, api_version: Optional[str] = None) -> object:
    if api_version is not None:
        raise ValueError("Unrecognized array API version")
    import torch
    return torch

# Monkey-Patch the protocol to torch Tensor class
setattr(torch.Tensor, '__array_namespace__', __array_namespace__)

get_namespace

get_namespace is the key method which will be required inside all the consumer libraries looking to support multiple array/tensor implementations. The idea is to check the passed array/tensor objects for the protocol compliance and accordingly make sure the same implementation is used going ahead by returning the array/tensor framework namespace represented in the form xp.

def get_namespace(*xs):
    # `xs` contains one or more arrays, or possibly Python scalars (accepting
    # those is a matter of taste, but doesn't seem unreasonable).
    namespaces = {
        x.__array_namespace__() if hasattr(x, '__array_namespace__')
        else None for x in xs if not isinstance(x, (bool, int, float, complex))
    }
    if not namespaces:
        raise ValueError("Unrecognized array input")
    if len(namespaces) != 1:
        raise ValueError(f"Multiple namespaces for array inputs: {namespaces}")
    xp, = namespaces
    if xp is None:
        raise ValueError("The input is not a supported array type")

    if xp == torch:
        # Monkey-Patch functions (Currently under progress for PyTorch)
        xp.asarray = torch.as_tensor
        xp.pi = math.pi

    return xp
# Test array-api `get_namespace`
t = torch.arange(3)
xp = get_namespace(t)

print(xp)
<module 'torch' from '/Users/gollum/anaconda3/envs/scipy-dev/lib/python3.9/site-packages/torch/__init__.py'>

As seen above get_namespace works as expected. Before we dive deeper and move into the real deal, I warn you that this notebook works for a specific fork of SciPy right now and is just a prototype showing the ability to use PyTorch tensors with SciPy. It may not work for you with a release version of SciPy.

If you really want to check the details of the changes made in SciPy to make this work, we encourage you to checkout and build the branch array-api-demo locally. Some functions like scipy.signal.welch have been adapted in SciPy such that they can support multiple array/tensor libraries. In this case PyTorch and Numpy. Feel free to explore more in spectral.py.

Let’s get started.

Calculate the noise power spectral density (PSD)

# these 4096 second files are useful for computing more accurate psds
large_data_filenames = {'GW150914': 'LOSC_4_V2-1126257414-4096.hdf5'}
fn_H1 = 'H-H1_' + large_data_filenames['GW150914']
fn_L1 = 'L-L1_' + large_data_filenames['GW150914']

# this is the approximate event time which is used in the papers
time_center = 1126259462

# get sample rate from the H1 data file
with h5py.File(fn_H1, 'r') as hdf_file:
    dt = xp.asarray((hdf_file['strain/Strain'].attrs['Xspacing']))
fs = int(1.0/dt)
print("Using sample rate of {0} Hz found in file: {1}".format(fs, fn_H1))
Using sample rate of 4096 Hz found in file: H-H1_LOSC_4_V2-1126257414-4096.hdf5
def get_full_psds(eventnames, large_data_filenames,
                  make_plots=False, plot_others=False):
    """Obtains full 1024 second psds for all the events specified. Uses the Welch
    average technique, along with other less accurate techniques if
    specified. Can also plot the psd obtained.
    
    Args: 
        eventnames (list): list of events to get psds for
        large_datafilenames (dict): dictionary whose keys are the eventnames 
            and whose values are the filenames of the large amounts of strain
            data used, without the added 'H-<det>_'
        make_plots (bool, optional): if set to True, plot psd data
        plot_others (bool, optional): if set to True, also obtain psd data 
            without averaging as well as with no window
    
    Returns:
        dict: A dictionary containing psds for each detector for each event 
            specified in eventnames.
    """

    large_data_psds = {}
    for eventname in eventnames:
        large_data_psds[eventname] = {'H1': [], 'L1': []}

        # get filename
        fn_H1 = 'H-H1_' + large_data_filenames[eventname]
        fn_L1 = 'L-L1_' + large_data_filenames[eventname]

        # get strain data
        # set the use_torch argument for loading data in pytorch tensors
        strain_H1, time_H1, chan_dict_H1 = rl.loaddata(fn_H1, 'H1', use_torch=True)
        strain_L1, time_L1, chan_dict_L1 = rl.loaddata(fn_L1, 'L1', use_torch=True)
        
        # validate the data loaded is actually as torch tensors
        print("Array/Tensor Implementation Used: ", type(strain_H1), type(time_H1))
        print("Array/Tensor Implementation Used: ", type(strain_L1), type(time_L1))

        # both H1 and L1 will have the same time vector, so:
        time = time_H1

        indxt_around = xp.where((time >= time_center - 512) & (
            time < time_center + 512))
        

        # number of sample for the fast fourier transform:
        NFFT = 4 * fs           # Use 4 seconds of data for each fourier transform
        NOVL = 1 * NFFT / 2     # The number of points of overlap between segments used in Welch averaging
        psd_window = scipy.signal.tukey(NFFT, alpha=1./4, use_torch=True)


        freqs, Pxx_H1 = scipy.signal.welch(strain_H1[indxt_around], fs=fs, nfft=NFFT,
                                 window=psd_window, noverlap=NOVL)
        freqs, Pxx_L1 = scipy.signal.welch(strain_L1[indxt_around], fs=fs, nfft=NFFT, 
                                window=psd_window, noverlap=NOVL)                

        if (plot_others):
            # smaller window if we're not doing Welch's method
            short_indxt_away = xp.where((time >= time_center - 2) & (
                time < time_center + 2))

            tukey_freqs, tukey_Pxx_H1 = scipy.signal.welch(strain_H1[short_indxt_away],
                                                           fs=fs, nfft=NFFT, window=psd_window)
            
            nowin_freqs, nowin_Pxx_H1 = scipy.signal.welch(strain_H1[short_indxt_away],
                                                           fs=fs, nfft=NFFT, window=xp.ones(NFFT))

        # We will use interpolations of the PSDs computed above for whitening:
        psd_H1 = interp1d(freqs.numpy(), Pxx_H1.numpy())
        psd_L1 = interp1d(freqs.numpy(), Pxx_L1.numpy())
        
        large_data_psds[eventname]['H1'] = psd_H1
        large_data_psds[eventname]['L1'] = psd_L1

        if make_plots:
            plt.figure(figsize=(8, 5))
            # scale x and y axes
            plt.xscale('log', base=2)
            plt.yscale('log', base=10)

            # plot nowindow, tukey, welch together 
            plt.plot(nowin_freqs.numpy(), nowin_Pxx_H1.numpy(),'purple',label= 'No Window',
                     alpha=.8, linewidth=.5)
            plt.plot(tukey_freqs.numpy(), tukey_Pxx_H1.numpy(),'green',label='Tukey Window',
                     alpha=.8, linewidth=.5)
            plt.plot(freqs.numpy(), Pxx_H1.numpy(),'black',label='Welch Average', alpha=.8,
                     linewidth=.5)

            # plot 1/f^2
            # give it the right starting scale to fit with the rest of the plots
            # don't include zero frequency
            inverse_square = xp.asarray(list(map(lambda f: 1 / (f**2), 
                                    nowin_freqs[1:])))
            # inverse starts at 1 to take out 1/0
            scale_index = 500 # chosen by eye to fit the plot
            scale = nowin_Pxx_H1[scale_index]  / inverse_square[scale_index]
            plt.plot(nowin_freqs[1:].numpy(), inverse_square.numpy() * scale.numpy(),'red',
                     label= r'$1 / f^2$', alpha=.8, linewidth=1)

            plt.axis([20, 512, 1e-48, 1e-41])
            plt.ylabel('Sn(t)')
            plt.xlabel('Freq (Hz)')
            plt.legend(loc='upper center')
            plt.title('LIGO PSD data near ' + eventname + ' at H1')
            plt.show()

    return large_data_psds
eventnames = ['GW150914']
large_data_psds = get_full_psds(eventnames, large_data_filenames,
                                make_plots=True, plot_others=True)
Array/Tensor Implementation Used:  <class 'torch.Tensor'> <class 'torch.Tensor'>
Array/Tensor Implementation Used:  <class 'torch.Tensor'> <class 'torch.Tensor'>
_images/GW_Demo_Array_API_15_1.png

Note that inadequate windowing of these strongly colored data produces a psd that is entirely dominated by “spectral leakage”, https://en.wikipedia.org/wiki/Spectral_leakage, and inadequate averaging leads to noise fluctuations that contaminate the estimated PSD.

Whiten and Bandpass Functions

Now we’ll create some helper functions to whiten and bandpass data within a given frequency band. These allow us to better see some signal in our data:

def whiten(strain, interp_psd, dt, phase_shift=0, time_shift=0):
    """Whitens strain data given the psd and sample rate, also applying a phase
    shift and time shift.

    Args:
        strain (ndarray): strain data
        interp_psd (interpolating function): function to take in freqs and output 
            the average power at that freq 
        dt (float): sample time interval of data
        phase_shift (float, optional): phase shift to apply to whitened data
        time_shift (float, optional): time shift to apply to whitened data (s)
    
    Returns:
        ndarray: array of whitened strain data
    """
    Nt = len(strain)
    # take the fourier transform of the data
    freqs = xp.fft.rfftfreq(Nt, dt)

    # whitening: transform to freq domain, divide by square root of psd, then
    # transform back, taking care to get normalization right.
    hf = xp.fft.rfft(strain)
    
    # apply time and phase shift
    hf = hf * xp.exp(-1.j * 2 * xp.pi * time_shift * freqs - 1.j * phase_shift)
    norm = 1./xp.sqrt(1./(dt*2))
    # interp1d needs array api compatibility
    white_hf = hf / xp.sqrt(torch.from_numpy(interp_psd(freqs))) * norm
    white_ht = xp.fft.irfft(white_hf, n=Nt)
    return white_ht

def bandpass(strain, fband, fs):
    """Bandpasses strain data using a butterworth filter.
    
    Args:
        strain (ndarray): strain data to bandpass
        fband (ndarray): low and high-pass filter values to use
        fs (float): sample rate of data
    
    Returns:
        ndarray: array of bandpassed strain data
    """
    # TODO: Figure way to make butter compatible
    bb, ab = butter(4, [fband[0]*2./fs, fband[1]*2./fs], btype='band')
    normalization = math.sqrt((fband[1]-fband[0])/(fs/2))
    # TODO: Figure way to make filtfilt compatible
    strain_bp = filtfilt(bb, ab, strain.numpy()) / normalization
    return torch.from_numpy(strain_bp)

Note that some of the functions here are still not array api compatible. Eg. butter, filtfilt etc. These involve thousands of LOC and we decided it wasn’t worth refactoring for the purpose of this demo.

This also brings us to one major limitation of Array API, i.e. it is designed to work only for pure python implementations of a function. We can’t support multiple array/tensor implementations using the __array_namespace__ protocol if the function itself is written in C, C++, Fortran, Cython etc which is not very rare for a library like SciPy.

Other backend handling protocols like uarray could work better in such a case.

Plot strain data that has been windowed, bandpassed, and whitened

def plot_strain_data(fn_H1, fn_L1, eventname):
    """Plots windowed, whitened, and bandpassed strain vs time near a given
    event.

    Args:
        fn_H1: name of H1 data file
        fn_L1: name of L1 data file
        eventname: name of the event
    """
    
    fband = [35.0, 350.0]

    strain_H1, time_H1, chan_dict_H1 = rl.loaddata(fn_H1, 'H1', use_torch=True)
    strain_L1, time_L1, chan_dict_L1 = rl.loaddata(fn_L1, 'L1', use_torch=True)

    # both H1 and L1 will have the same time vector
    time = time_H1
    
    # create our 4 second data window
    window_len = 4*fs
    dwindow = scipy.signal.tukey(window_len, alpha=1./4, use_torch=True)

    # plot original strain data
    # only care about 4s around event
    plt.figure(figsize=(8, 8))
    plt.subplot(4, 1, 1)
    indxt = xp.where((time >= time_center - 2) & (time < time_center + 2))
    plt.xlim([-2, 2])
    plt.plot(time[indxt] - time_center, strain_H1[indxt], 'blue', 
             label='Hanford Data', linewidth=.5)
    plt.legend()

    # plot windowed data
    plt.subplot(4, 1, 2)
    strain_windowed = dwindow * strain_H1[indxt]
    plt.xlim([-2, 2])
    plt.plot(time[indxt].numpy() - time_center, strain_windowed, 'green', 
             label='Windowed Data', linewidth=.5)
    plt.legend()
    
    
    # plot whitened data
    plt.subplot(4, 1, 3)
    strain_whitened = whiten(strain_windowed, 
                             large_data_psds[eventname]['H1'], dt)
    plt.plot(time[indxt] - time_center, strain_whitened, 'red', 
             label='Whitened Data', linewidth=.5)
    plt.legend()

    # plot bandpassed data
    plt.subplot(4, 1, 4)
    strain_bp = bandpass(strain_whitened, fband, fs)
    plt.xlim([-2, 2])
    plt.plot(time[indxt].numpy() - time_center, strain_bp.numpy(), 'black', 
             label='Bandpassed Data', linewidth=.5)

    plt.legend()
    plt.yticks([-6, -3, 0, 3, 6, 9])
    plt.tight_layout()
    plt.xlim([-2, 2])
    plt.ylim([-8, 8])
    plt.show()
plot_strain_data(fn_H1, fn_L1, 'GW150914')
_images/GW_Demo_Array_API_22_0.png

Method for getting whitened/banspassed event data

def get_strain_whitenbp_data(fn_H1, fn_L1, fband, eventname):
    
    # get strain data
    strain_H1, time_H1, chan_dict_H1 = rl.loaddata(fn_H1, 'H1', use_torch=True)
    strain_L1, time_L1, chan_dict_L1 = rl.loaddata(fn_L1, 'L1', use_torch=True)

    # both H1 and L1 will have the same time vector, so:
    time = time_H1

    # whiten, bandpass the data
    strain_H1_whiten = whiten(strain_H1, large_data_psds[eventname]['H1'], dt)
    strain_L1_whiten = whiten(strain_L1, large_data_psds[eventname]['L1'], dt)

    strain_H1_whitenbp = bandpass(strain_H1_whiten, fband, fs)
    strain_L1_whitenbp = bandpass(strain_L1_whiten, fband, fs)
    
    # return results as a dictionary for more intuitive access
    total_data = {'H1': {'strain': strain_H1, 'strain_whiten': strain_H1_whiten,
                         'strain_whitenbp': strain_H1_whitenbp}, 'L1': 
                  {'strain': strain_L1, 'strain_whiten': strain_L1_whiten, 
                   'strain_whitenbp': strain_L1_whitenbp}, 'time': time, 'dt': dt}
                   
    return total_data
fn_H1_32 = 'H-H1_LOSC_4_V2-1126259446-32.hdf5'
fn_L1_32 = 'L-L1_LOSC_4_V2-1126259446-32.hdf5'
total_GW150914_data = get_strain_whitenbp_data(fn_H1_32, fn_L1_32, [35., 350.], 'GW150914')
def plot_fourier_phases(data, time): 
    """Plots Fourier phases of strain data.

    Args:
        data (ndarray): strain data
        time (ndarray): time corresponding to data.
    """
    # do this both with and without a spectral window
    plt.figure(figsize=(8, 5))
    tukey_window = signal.tukey(data.numel(), alpha=1./4, use_torch=True)
    labels = ['Tukey window', 'No window']
    colors = ['red', 'blue']
    
    # apply each window and find phases of the fft
    for i, dwindow in enumerate([tukey_window, 1]):
        data_fft = xp.fft.rfft((data * dwindow).flatten()) * dt
        datafreq = xp.fft.rfftfreq(data.numel()) / dt
        phases = xp.angle(data_fft)
        # convert angles to range (0, 2 pi)
        phases = (phases + 2 * xp.pi) % (2 * xp.pi)
        plt.plot(datafreq, phases, '.', label=labels[i], 
                 color=colors[i], markersize=1)
        
    plt.xlim([30, 400])
    plt.ylim([0, 6])
    plt.xlabel('f(Hz)')
    plt.ylabel('Phase')
    plt.legend(loc='upper right', fontsize=7)
indxt = xp.where((total_GW150914_data['time'] >= time_center - 2) & (
    total_GW150914_data['time'] < time_center + 2))
plot_fourier_phases(total_GW150914_data['H1']['strain'][indxt],
                    total_GW150914_data['time'][indxt])
_images/GW_Demo_Array_API_27_0.png

Note that the fourier phases for properly windowed data (random noise) are random between [0, 2*pi]; by contrast, un-windowed data produce strongly correlated fourier phases that are artifacts of the abrupt beginning and end of the data stretch.

Summary Array API

As seen from the analysis plots above, we try to use PyTorch tensors with an array consuming library like SciPy which was initially built only keeping NumPy in mind. With the simple Python Array API tweaks we were able to do that. This simply means that any project which is written in pure python and consumes an array api compliant framework, would be able to easily support multiple kinds of array/tensors with their modules.

This also means that switching an implementation should be extremely easy for an end user since the code will be untouched. We’ve replaced everything with an xp namespace that in theory might represent any array/tensor implementation the user wants.

Obviously not everything is rainbows and sunshine there are a few places where Array API is limited with it’s power. Neither does it aim, nor it will ever handle compiled code. The only suggestion made by array api is to follow a standard and with that little change we were to able to make a lot of interesting in roads in the very interesting direction of Array Libraries Interoperability.

Also we cannot change the way matplotlib generates the plots. One will still need to convert their arrays/tensors to NumPy ndarray before fetching the plotting data into matplotlib.

Through this prototype demo we show how modules like scipy.signal can be changed and updated to support Array API. Functions like scipy.signal.welch, scipy.signal.csd, scipy.signal.tukey etc. showcase the power of the Python Array API standard.

The spec is still under development and a big team is working towards its success. This demo and prototype will become more solid as we move towards the first spec release of the standard (expected end of 2021). Till then I hope the demo is comprehensive enough for people to understand one of the goals/use-cases and how they will be achieved with Python Array API.