#!python
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
#   See COPYING file distributed along with the NiBabel package for the
#   copyright and license terms.
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""
Output a summary table for neuroimaging files (resolution, dimensionality, etc.)
"""
from __future__ import division, print_function, absolute_import

import re
import sys

import numpy as np
import nibabel as nib

from math import ceil
from collections import defaultdict
from optparse import OptionParser, Option
from io import StringIO
from nibabel.py3k import asunicode
from nibabel.externals.six.moves import xrange

# Module metadata
__author__ = 'Yaroslav Halchenko'
__copyright__ = 'Copyright (c) 2011-2016 Yaroslav Halchenko ' \
                'and NiBabel contributors'
__license__ = 'MIT'


# global verbosity switch: `verbose` prints a message only if its level does
# not exceed this value; `main` overwrites it with the -v/--verbose count
verbose_level = 0
# maximal number of unique values to report for --counts (the --all-counts
# help text advertises this limit)
MAX_UNIQUE = 1000  # maximal number of unique values to report for --counts

def _err(msg=None):
    """To return a string to signal "error" in output table"""
    if msg is None:
        msg = 'error'
    return '!' + msg

def verbose(l, msg):
    """Print `msg` indented by `l` spaces unless `l` exceeds `verbose_level`
    """
    # TODO: consider using nibabel's logger
    if l > int(verbose_level):
        # message is more detailed than the requested verbosity
        return
    indent = ' ' * l
    print("%s%s" % (indent, msg))


def error(msg, exit_code):
    """Print `msg` to standard error and terminate the program.

    Parameters
    ----------
    msg : str
      Message to report.
    exit_code : int
      Value passed to `sys.exit`.

    Raises
    ------
    SystemExit
      Always, via `sys.exit(exit_code)`.
    """
    # This module enables `print_function`, so the Python 2 statement form
    # `print >> sys.stderr, msg` would fail with a TypeError at runtime.
    print(msg, file=sys.stderr)
    sys.exit(exit_code)


def table2string(table, out=None):
    """Given list of lists figure out their common widths and render a table

    A cell may start with an alignment marker: ``@l`` (left), ``@r`` (right),
    ``@c`` (center, also the default without any marker) or ``@w`` (wide:
    positioned like a centered cell, but ignored while computing the width
    of its column).

    Parameters
    ----------
    table : list of lists of strings
      What is aimed to be printed.  Cells are converted with `str`.  Rows
      shorter than the longest one are treated as if padded with empty
      cells; the input itself is not modified.
    out : None or stream
      Where to write.  If None -- nothing is written anywhere and the
      rendered table is returned as a string

    Returns
    -------
    string if out was None

    Raises
    ------
    ValueError
      If a cell starts with '@' which is not followed by one of l, r, c, w
    """

    print2string = out is None
    if print2string:
        out = StringIO()

    # equalize number of elements in each row -- on a copy, so that the
    # rows handed in by the caller are left untouched
    nelements_max = max(len(row) for row in table) if len(table) else 0
    rows = [[str(cell) for cell in row] + [''] * (nelements_max - len(row))
            for row in table]

    # figure out lengths within each column
    # eat whole entry while computing width for @w (for wide)
    markup_strip = re.compile(r'^@([lrc]|w.*)')
    col_width = [max(len(markup_strip.sub('', row[j])) for row in rows)
                 for j in range(nelements_max)]

    lines = []
    for row in rows:
        cells = []
        for width, item in zip(col_width, row):
            if item.startswith('@'):
                # slicing (not indexing) so that a bare '@' is reported as
                # an unknown alignment rather than as an IndexError
                align = item[1:2]
                item = item[2:]
                if align not in ('l', 'r', 'c', 'w'):
                    raise ValueError(
                        'Unknown alignment %s. Known are l,r,c,w' % align)
            else:
                align = 'c'

            # centered split of the padding; the left side gets the extra
            # space for odd amounts.  @w cells may be longer than the
            # column, hence the clamping at 0
            nspacesl = max((width - len(item) + 1) // 2, 0)
            nspacesr = max(width - nspacesl - len(item), 0)

            if align == 'l':
                nspacesl, nspacesr = 0, nspacesl + nspacesr
            elif align == 'r':
                nspacesl, nspacesr = nspacesl + nspacesr, 0
            # 'c' and 'w' keep the centered split

            cells.append(' ' * nspacesl + item + ' ' * nspacesr)
        # columns are separated by a single space; trailing blanks dropped
        lines.append(' '.join(cells).rstrip() + '\n')
    out.write(asunicode(''.join(lines)))

    if print2string:
        value = out.getvalue()
        out.close()
        return value


def ap(l, format_, sep=', '):
    """Little helper to enforce consistency

    Formats every element of `l` with `format_` and joins the results.

    Parameters
    ----------
    l : iterable or '-'
      Values to format.  The string '-' (the placeholder `safe_get` returns
      on failure) is passed through unchanged.
    format_ : str
      %-style format applied to each element.
    sep : str, optional
      Separator placed between formatted elements.

    Returns
    -------
    str
    """
    # Compare against the placeholder only for strings: `array == '-'` is an
    # elementwise comparison for NumPy arrays and cannot be used in `if`.
    # `type(u'')` also covers unicode strings under Python 2.
    if isinstance(l, (str, type(u''))) and l == '-':
        return l
    return sep.join(format_ % x for x in l)


def safe_get(obj, name):
    """Call ``obj.get_<name>()`` and return its result, or '-' on any failure
    """
    try:
        return getattr(obj, 'get_' + name)()
    except Exception as exc:
        # best effort: report at verbosity 2 and fall back to a placeholder
        verbose(2, "get_%s() failed -- %s" % (name, exc))
    return '-'


def get_opt_parser():
    """Build the `OptionParser` describing the command line of this tool"""
    # use module docstring for help output
    usage = "%s [OPTIONS] [FILE ...]\n\n" % sys.argv[0]
    parser = OptionParser(usage=usage + __doc__,
                          version="%prog " + nib.__version__)

    parser.add_option(
        "-v", "--verbose", action="count", dest="verbose", default=0,
        help="Make more noise.  Could be specified multiple times")

    parser.add_option(
        "-H", "--header-fields", dest="header_fields", default='',
        help="Header fields (comma separated) to be printed as well (if present)")

    parser.add_option(
        "-s", "--stats", action="store_true", dest='stats', default=False,
        help="Output basic data statistics")

    parser.add_option(
        "-c", "--counts", action="store_true", dest='counts', default=False,
        help="Output counts - number of entries for each numeric value "
             "(useful for int ROI maps)")

    parser.add_option(
        "--all-counts", action="store_true", dest='all_counts', default=False,
        help="Output all counts, even if number of unique values > %d" % MAX_UNIQUE)

    parser.add_option(
        "-z", "--zeros", action="store_true", dest='stats_zeros', default=False,
        help="Include zeros into output basic data statistics (--stats, --counts)")

    return parser


def proc_file(f, opts):
    """Collect the cells of one summary-table row for the file `f`.

    Parameters
    ----------
    f : str
      Name of the file, passed to `nib.load`.
    opts : object
      Parsed command-line options.  The attributes `header_fields`,
      `stats`, `counts`, `all_counts` and `stats_zeros` are read.

    Returns
    -------
    row : list of str
      Cells for `table2string`: the file name, then either the single cell
      'failed' (file or header could not be loaded) or dtype, shape, zooms,
      slope/intercept, number of extensions, requested header fields, an
      'sform' marker and, if requested, statistics and value counts.
    """
    verbose(1, "Loading %s" % f)

    row = ["@l%s" % f]
    try:
        vol = nib.load(f)
        h = vol.header
    except Exception as e:
        row += ['failed']
        verbose(2, "Failed to gather information -- %s" % str(e))
        return row

    row += [str(safe_get(h, 'data_dtype')),
            '@l[%s]' % ap(safe_get(h, 'data_shape'), '%3g'),
            '@l%s' % ap(safe_get(h, 'zooms'), '%.2f', 'x')]

    # Slope: reported only if it differs from the identity scaling
    slope_cell = ''
    if hasattr(h, 'has_data_slope') and \
            (h.has_data_slope or h.has_data_intercept):
        slope, inter = h.get_slope_inter()
        # (None, None) is treated like (1.0, 0.0), i.e. nothing to report.
        # Substitute the identity value for a single None as well, so that
        # the %-formatting below cannot fail on it.
        if slope is None:
            slope = 1.0
        if inter is None:
            inter = 0.0
        if (slope, inter) != (1.0, 0.0):
            slope_cell = '@l*%.3g+%.3g' % (slope, inter)
    row += [slope_cell]

    if hasattr(h, 'extensions') and len(h.extensions):
        row += ['@l#exts: %d' % len(h.extensions)]
    else:
        row += ['']

    if opts.header_fields:
        # signals "all fields"
        if opts.header_fields == 'all':
            # TODO: might vary across file types, thus prior sensing
            # would be needed
            header_fields = h.keys()
        else:
            header_fields = opts.header_fields.split(',')

        # `field`, not `f`, so that the file name argument is not clobbered
        for field in header_fields:
            if not field:  # skip empty
                continue
            try:
                row += [str(h[field])]
            except (KeyError, ValueError):
                row += [_err()]

    # mark files whose qform and sform differ
    try:
        if (hasattr(h, 'get_qform') and hasattr(h, 'get_sform') and
                (h.get_qform() != h.get_sform()).any()):
            row += ['sform']
        else:
            row += ['']
    except Exception as e:
        verbose(2, "Failed to obtain qform or sform -- %s" % str(e))
        if isinstance(h, nib.AnalyzeHeader):
            row += ['']
        else:
            row += [_err()]

    if opts.stats or opts.counts:
        # We are doomed to load data
        try:
            d = vol.get_data()
            if not opts.stats_zeros:
                d = d[np.nonzero(d)]
            else:
                # at least flatten it -- functionality below doesn't
                # depend on the original shape, so let's use a flat view
                d = d.reshape(-1)
            if opts.stats:
                # just # of elements
                row += ["@l[%d]" % np.prod(d.shape)]
                # stats: min and max, or a placeholder if nothing is left
                row += ['@l[%.2g, %.2g]' % (np.min(d), np.max(d))
                        if len(d) else '-']
            if opts.counts:
                items, inv = np.unique(d, return_inverse=True)
                # same limit as advertised in the --all-counts help text
                if len(items) > MAX_UNIQUE and not opts.all_counts:
                    counts = _err("%d uniques. Use --all-counts" % len(items))
                else:
                    freq = np.bincount(inv)
                    counts = " ".join("%g:%d" % (value, count)
                                      for value, count in zip(items, freq))
                row += ["@l" + counts]
        except IOError as e:
            verbose(2, "Failed to obtain stats/counts -- %s" % str(e))
            row += [_err()]
    return row


def main():
    """Parse the command line, summarize every listed file, print the table"""
    global verbose_level

    opts, files = get_opt_parser().parse_args()
    verbose_level = opts.verbose

    # suppress nibabel format-compliance warnings
    # unless verbosity of 3 or more was requested
    if verbose_level < 3:
        nib.imageglobals.logger.level = 50

    rows = []
    for fname in files:
        rows.append(proc_file(fname, opts))

    print(table2string(rows))


# Run the command-line tool only when executed as a script, not on import
if __name__ == '__main__':
    main()
