# Source code for morpheus_core.morpheus_core

# MIT License
# Copyright 2020 Ryan Hausen
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
# ==============================================================================

import os
from functools import partial
from itertools import islice, repeat, starmap, takewhile
from typing import Callable, List, Tuple, Union

import numpy as np
from astropy.io import fits
from tqdm import tqdm

from morpheus_core.helpers import misc_helper
from morpheus_core.helpers import fits_helper
from morpheus_core.helpers import label_helper
from morpheus_core.helpers import parallel_helper

__all__ = ["AGGREGATION_METHODS", "predict"]


class AGGREGATION_METHODS:
    """Helper class with string constants to use as arguments in morpheus_core methods.

    The constants are compared against the ``aggregate_method`` argument of
    the prediction functions in this module to pick how model outputs are
    aggregated.
    """

    # Aggregate model outputs by recording their mean and variance.
    MEAN_VAR = "mean_var"

    # Aggregate model outputs as a normalized vote count.
    RANK_VOTE = "rank_vote"

    # Message used for the ValueError raised when an unknown method is given.
    INVALID_ERR = " ".join(
        [
            "Invalid aggregation method please select one of",
            "AGGREGATION_METHODS.MEAN_VAR or AGGREGATION_METHODS.RANK_VOTE",
        ]
    )


def build_batch(
    arr: List[np.ndarray],
    window_size: Tuple[int, int],
    batch_idxs: List[Tuple[int, int]],
) -> Tuple[List[np.ndarray], List[Tuple[int, int]]]:
    """Builds a batch of samples of `window_size` from `arr` at `batch_idxs`.

    Args:
        arr (List[np.ndarray]): array(s) to extract values from
        window_size (Tuple[int, int]): (height, width) of batch samples
        batch_idxs (List[Tuple[int, int]]): List of (y,x) locations to sample

    Returns:
        Returns a 2-Tuple where the first element is the batch (one stacked
        array per input array in `arr`) and the second element is the list of
        batch idxs, passed through unchanged.
    """

    def grab_slice(in_array: np.ndarray, dim0: int, dim1: int) -> np.ndarray:
        # Only the first two axes are windowed; trailing axes are kept whole.
        return in_array[dim0 : dim0 + window_size[0], dim1 : dim1 + window_size[1], ...]

    def grab_batch(in_array: np.ndarray) -> np.ndarray:
        return np.array([grab_slice(in_array, dim0, dim1) for dim0, dim1 in batch_idxs])

    batches = [grab_batch(in_array) for in_array in arr]

    return batches, batch_idxs


def predict_batch(
    model_f: Callable, batch: List[np.ndarray], batch_idxs: List[Tuple[int, int]]
) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
    """Calls the model function on a batch.

    Args:
        model_f (Callable): Model function that predicts on a batch
        batch (List[np.ndarray]): batch values
        batch_idxs (List[Tuple[int, int]]): (y,x) locations for batch values

    Returns:
        A 2-Tuple where the first element is the output of model on the given
        batch and the second element is the batch indexes associated with the
        output.
    """
    return model_f(batch), batch_idxs


def update_output(
    aggregate_method: str,
    update_map: np.ndarray,
    stride: Tuple[int, int],
    dilation: float,
    n: np.ndarray,
    outputs: np.ndarray,
    batch_out: np.ndarray,
    batch_idx: Tuple[int, int],
) -> None:
    """Updates the total output with a single output value

    Args:
        aggregate_method (str): How to process the output from the model. If
                                AGGREGATION_METHODS.MEAN_VAR record output using
                                mean and variance, If
                                AGGREGATION_METHODS.RANK_VOTE record output as
                                the normalized vote count.
        update_map (np.ndarray): A boolean mask that indicates what pixels in
                                 each example to update
        stride (Tuple[int, int]): How many (rows, columns) to move through the
                                  image at each iteration.
        dilation (float): Factor applied to `stride` and `batch_idx` (then
                          truncated to int) to translate input coordinates
                          into output coordinates.
        n (np.ndarray): The array containing the n values
        outputs (np.ndarray): The array containing the aggregated output values
        batch_out (np.ndarray): The output from the model to incorporate into
                                outputs
        batch_idx (Tuple[int, int]): The (y,x) location in the larger image that
                                     the batch_out should be incorporated into

    Returns:
        None

    Raises:
        ValueError if aggregate_method is not one of AGGREGATION_METHODS.MEAN_VAR
        or AGGREGATION_METHODS.RANK_VOTE
    """

    def dilate(values: Tuple[int, int]) -> Tuple[int, ...]:
        return tuple(int(dilation * val) for val in values)

    dilated_batch_idx = dilate(batch_idx)
    dilated_stride = dilate(stride)

    if aggregate_method == AGGREGATION_METHODS.MEAN_VAR:
        label_helper.update_mean_var(
            update_map, dilated_stride, n, outputs, batch_out, dilated_batch_idx
        )
    elif aggregate_method == AGGREGATION_METHODS.RANK_VOTE:
        label_helper.update_rank_vote(
            update_map, dilated_stride, n, outputs, batch_out, dilated_batch_idx
        )
    else:
        raise ValueError(AGGREGATION_METHODS.INVALID_ERR)


# NOTE: the name is a misspelling of "update_batch"; it is kept as-is so that
# existing callers of this function keep working.
def udpate_batch(
    aggregate_method: str,
    update_map: np.ndarray,
    stride: Tuple[int, int],
    dilation: float,
    n: np.ndarray,
    outputs: np.ndarray,
    batch_out: np.ndarray,  # [n, w, h, c]
    batch_idxs: List[Tuple[int, int]],  # [n, 2]
) -> None:
    """Updates the total output with the batch output values

    Args:
        aggregate_method (str): How to process the output from the model. If
                                AGGREGATION_METHODS.MEAN_VAR record output using
                                mean and variance, If
                                AGGREGATION_METHODS.RANK_VOTE record output as
                                the normalized vote count.
        update_map (np.ndarray): A boolean mask that indicates what pixels in
                                 each example to update
        stride (Tuple[int, int]): How many (rows, columns) to move through the
                                  image at each iteration.
        dilation (float): Factor forwarded to `update_output`, which uses it to
                          scale `stride` and each batch index.
        n (np.ndarray): The array containing the n values
        outputs (np.ndarray): The array containing the aggregated output values
        batch_out (np.ndarray): The output from the model to incorporate into
                                outputs
        batch_idxs (List[Tuple[int, int]]): A list of (y,x) locations, one for
                                            each of the output arrays in
                                            `batch_out`

    Returns:
        None
    """
    update_f = partial(
        update_output, aggregate_method, update_map, stride, dilation, n, outputs,
    )

    # Each (single_output, (y, x)) pair supplies the two remaining arguments
    # of `update_output`.
    misc_helper.apply(update_f, zip(batch_out, batch_idxs))


def predict_arrays(
    model: Callable,
    model_inputs: List[np.ndarray],
    n_classes: int,
    batch_size: int,
    window_shape: Tuple[int, int],
    dilation: float = 1,
    stride: Tuple[int, int] = (1, 1),
    update_map: np.ndarray = None,
    aggregate_method: str = AGGREGATION_METHODS.RANK_VOTE,
    out_dir: str = None,
) -> Tuple[List[fits.HDUList], List[np.ndarray]]:
    """Applies the given model on the given inputs and returns the output.

    Args:
        model (Callable): The model to apply to the inputs
        model_inputs (List[np.ndarray]): The input arrays to the model as a
                                         list
        n_classes (int): The number of output classes
        batch_size (int): The number of examples to include in each batch
        window_shape (Tuple[int, int]): The (height, width) of the samples to
                                        extract
        dilation (float): Factor by which the output height and width are
                          scaled relative to the input (truncated to int).
                          Accepted when it is greater than 1 or has no
                          fractional part.
        stride (Tuple[int, int]): How many (rows, columns) to move through the
                                  image at each iteration.
        update_map (np.ndarray): A 2D array that indicates which pixels to use
                                 for updates for each example. Defaults to an
                                 array of ones with the shape of `window_shape`
                                 scaled by `dilation`.
        aggregate_method (str): How to process the output from the model. If
                                AGGREGATION_METHODS.MEAN_VAR record output using
                                mean and variance, If
                                AGGREGATION_METHODS.RANK_VOTE record output as
                                the normalized vote count.
        out_dir (str): Where to store the output arrays

    Returns:
        A 2-Tuple where the first element is the list of HDULs
        ``[labels, n]`` returned by the label helpers and the second element
        is the list of matching arrays ``[labels, n]``.

    Raises:
        ValueError if `dilation` is invalid
        ValueError if aggregate_method is not one of AGGREGATION_METHODS.MEAN_VAR
        or AGGREGATION_METHODS.RANK_VOTE
    """
    model_inputs = list(map(np.atleast_3d, model_inputs))
    in_shape = model_inputs[0].shape[:-1]

    # The validity rule does not depend on the image shape, so it is checked
    # once rather than once per spatial dimension.
    if not (dilation > 1 or not bool(dilation % float(1))):
        raise ValueError("Invalid dilation value.")

    out_shape = [*(int(dim * dilation) for dim in in_shape), n_classes]

    def out_dir_f(file_name: str):
        return os.path.join(out_dir, file_name) if out_dir else None

    if update_map is None:
        # np.ones requires integer dimensions; truncate the same way as
        # out_shape so that float dilation values are supported.
        update_map = np.ones([int(dim * dilation) for dim in window_shape])

    if aggregate_method == AGGREGATION_METHODS.MEAN_VAR:
        hdul_lbl, arr_lbl = label_helper.get_mean_var_array(
            out_shape, out_dir_f("output.fits")
        )
    elif aggregate_method == AGGREGATION_METHODS.RANK_VOTE:
        hdul_lbl, arr_lbl = label_helper.get_rank_vote_array(
            out_shape, out_dir_f("output.fits")
        )
    else:
        raise ValueError(AGGREGATION_METHODS.INVALID_ERR)

    hdul_n, arr_n = label_helper.get_n_array(out_shape[:-1], out_dir_f("n.fits"))

    indicies = label_helper.get_windowed_index_generator(in_shape, window_shape, stride)

    window_dim0, window_dim1 = window_shape
    stride_dim0, stride_dim1 = stride
    # NOTE(review): this estimate only feeds the progress bar. Whether it
    # matches the number of indices yielded by
    # label_helper.get_windowed_index_generator for strides > 1 should be
    # confirmed against that helper.
    num_idxs = ((in_shape[0] - window_dim0 + 1) // stride_dim0) * (
        (in_shape[1] - window_dim1 + 1) // stride_dim1
    )
    # Batches are drawn until the index generator is exhausted, so a final
    # partial batch still counts: round up instead of down.
    num_batches = -(-num_idxs // batch_size)

    # Repeatedly pull `batch_size` indices; stop at the first empty pull.
    batch_generator = (list(islice(indicies, batch_size)) for _ in repeat(None))
    batch_indices = takewhile(lambda idxs: len(idxs) > 0, batch_generator)

    batch_func = partial(build_batch, model_inputs, window_shape)
    batches_and_idxs = map(batch_func, batch_indices)

    classify_func = partial(predict_batch, model)

    update_func = partial(
        udpate_batch, aggregate_method, update_map, stride, dilation, arr_n, arr_lbl,
    )

    # TODO: Implement an async queue system for predicting and updating results
    with tqdm(total=num_batches, desc="classifying", unit="batch") as pbar:
        for _ in starmap(update_func, starmap(classify_func, batches_and_idxs)):
            pbar.update()

    hduls = [hdul_lbl, hdul_n]
    outputs = [arr_lbl, arr_n]

    return hduls, outputs


def predict(
    model: Callable,
    model_inputs: List[Union[np.ndarray, str]],
    n_classes: int,
    batch_size: int,
    window_shape: Tuple[int, int],
    dilation: float = 1,
    stride: Tuple[int, int] = (1, 1),
    update_map: np.ndarray = None,
    aggregate_method: str = AGGREGATION_METHODS.RANK_VOTE,
    out_dir: str = None,
    gpus: List[int] = None,
    cpus: int = None,
    parallel_check_interval: float = 1,
) -> Tuple[List[fits.HDUList], List[np.ndarray]]:
    """Applies the `model` to the `model_inputs`

    If you are using the parallel functionality, then `model` must be
    pickleable.

    Args:
        model (Callable): The model to apply to the inputs
        model_inputs (List[Union[np.ndarray, str]]): The inputs to classify
                                                     using the given `model`
        n_classes (int): The number of classes that are output
        batch_size (int): The number of examples to include in a batch
        window_shape (Tuple[int, int]): The (height, width) of the samples to
                                        extract
        dilation (float): Scaling factor forwarded to the classification
                          routines.
        stride (Tuple[int, int]): How many (rows, columns) to move through the
                                  image at each iteration.
        update_map (np.ndarray): A 2D array of the same size as window height
                                 that indicates which pixels to use for updates
                                 for each example
        aggregate_method (str): How to process the output from the model. If
                                AGGREGATION_METHODS.MEAN_VAR record output using
                                mean and variance, If
                                AGGREGATION_METHODS.RANK_VOTE record output as
                                the normalized vote count.
        out_dir (str): The directory to save output files in if the
                       `model_inputs` are string locations.
        gpus (List[int]): The gpu ids to use for parallel processing
        cpus (int): The number of cpus to use for parallel processing
        parallel_check_interval (float): Forwarded to
                                         `parallel_helper.run_parallel_jobs`;
                                         presumably the polling interval for
                                         the parallel workers -- confirm
                                         against that helper.

    Returns:
        A 2-Tuple where the first element is the list of fits.HDULs for the
        output files. The second element is a list of the output arrays from
        the model given the input arrays.

    Raises:
        ValueError if `model_inputs` are not all of the same type
        ValueError if `model_inputs` are not str or np.ndarray
        ValueError if both gpus and cpus are given
        ValueError if cpus or gpus are given, but out_dir is not given
        ValueError if len(gpus)==1
        ValueError if cpus<2
    """
    inputs_are_str = misc_helper.vaidate_input_types_is_str(model_inputs)
    workers, is_gpu = misc_helper.validate_parallel_params(gpus, cpus, out_dir)

    if inputs_are_str:
        in_hduls, inputs = fits_helper.open_files(model_inputs, "readonly")
    else:
        in_hduls, inputs = [], model_inputs

    # Close any input files opened above even when classification fails, so
    # an error does not leak open file handles.
    try:
        if len(workers) == 1:
            out_hduls, outputs = predict_arrays(
                model,
                inputs,
                n_classes,
                batch_size,
                window_shape,
                dilation,
                stride,
                update_map,
                aggregate_method,
                out_dir,
            )
        else:
            parallel_helper.build_parallel_classification_structure(
                model,
                inputs,
                model_inputs,
                n_classes,
                batch_size,
                window_shape,
                dilation,
                stride,
                update_map,
                aggregate_method,
                out_dir,
                workers,
            )

            parallel_helper.run_parallel_jobs(
                workers, is_gpu, out_dir, parallel_check_interval
            )

            out_hduls, outputs = parallel_helper.stitch_parallel_classifications(
                workers, out_dir, aggregate_method, window_shape
            )
    finally:
        misc_helper.apply(lambda hdul: hdul.close(), in_hduls)

    return out_hduls, outputs