#!/usr/bin/python3
#
# This file is part of Plinth.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

"""
Configuration helper for monkeysphere.
"""

import argparse
import augeas
import json
import os
import psutil
import re
import signal
import subprocess


def parse_arguments():
    """Return the parsed command line arguments.

    The result is an ``argparse.Namespace`` whose ``subcommand`` attribute
    holds the name of the selected sub command; the remaining attributes
    depend on that sub command:

    - ``host-show-keys``: ``key_id`` (optional, ``None`` when omitted)
    - ``host-import-key``: ``ssh_fingerprint`` and ``domains`` (a list)
    - ``host-publish-key``: ``key_ids`` (a list, possibly empty)
    - ``host-cancel-publish``: ``pid`` (a string)

    A sub command is mandatory.  When it is missing argparse prints a
    usage error and exits instead of returning a namespace whose
    ``subcommand`` is ``None``.
    """
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest='subcommand', help='Sub command')
    # Sub commands are optional by default in Python 3.  Set the attribute
    # (rather than passing required=True) so this also works on Python
    # versions older than 3.7.
    subparsers.required = True

    host_show_keys = subparsers.add_parser(
        'host-show-keys', help='Show imported/importable keys')
    host_show_keys.add_argument(
        'key_id', nargs='?', help='Optional KEYID to retrieve details for')

    host_import_key = subparsers.add_parser(
        'host-import-key', help='Import a key into monkeysphere')
    host_import_key.add_argument(
        'ssh_fingerprint', help='SSH fingerprint of the key to import')
    host_import_key.add_argument(
        'domains', nargs='*', help='List of available domains')

    host_publish_key = subparsers.add_parser(
        'host-publish-key', help='Push host key to keyserver')
    host_publish_key.add_argument(
        'key_ids', nargs='*', help='Optional list of KEYIDs')

    host_cancel_publish = subparsers.add_parser(
        'host-cancel-publish', help='Cancel a running publish operation')
    host_cancel_publish.add_argument('pid', help='PID of the publish process')

    return parser.parse_args()


def get_ssh_keys(fingerprint_hash):
    """Return the host SSH keys indexed by their fingerprint.

    ``fingerprint_hash`` is passed to ``ssh-keygen -E`` to select the
    hash used for the fingerprint.  Each value is a dictionary with the
    entries ``ssh_fingerprint``, ``service`` (always ``'ssh'``),
    ``key_file`` and ``available_domains`` (always ``['*']``).

    Only the RSA host key is examined.
    """
    host_key_paths = ('/etc/ssh/ssh_host_rsa_key',)

    found = {}
    for path in host_key_paths:
        command = ['ssh-keygen', '-l', '-E', fingerprint_hash, '-f', path]
        # The fingerprint is the second whitespace separated field of
        # the ssh-keygen output.
        fields = subprocess.check_output(command).decode().split()
        found[fields[1]] = {
            'ssh_fingerprint': fields[1],
            'service': 'ssh',
            'key_file': path,
            'available_domains': ['*'],
        }

    return found


def get_pem_ssh_fingerprint(pem_file, fingerprint_hash):
    """Return the SSH style fingerprint for the key in a PEM file.

    The data is passed through three commands, each one consuming the
    output of the previous one: ``openssl rsa -pubout`` reads
    ``pem_file``, ``ssh-keygen -i -m PKCS8`` converts the result and
    ``ssh-keygen -l -E fingerprint_hash`` prints the fingerprint line.
    The second whitespace separated field of that line is returned.
    """
    data = subprocess.check_output(
        ['openssl', 'rsa', '-in', pem_file, '-pubout'],
        stderr=subprocess.DEVNULL)

    # Both ssh-keygen invocations read their input from stdin.
    for options in (['-i', '-m', 'PKCS8'], ['-l', '-E', fingerprint_hash]):
        data = subprocess.check_output(
            ['ssh-keygen'] + options + ['-f', '/dev/stdin'], input=data)

    fields = data.decode().split()
    return fields[1]


def get_https_keys(fingerprint_hash):
    """Return the HTTPS keys of Apache virtual hosts, by SSH fingerprint.

    Every ``VirtualHost`` found in ``/etc/apache2/sites-available/`` that
    has a ``GnuTLSKeyFile`` or ``SSLCertificateKeyFile`` directive yields
    one entry.  Its ``available_domains`` is the ``ServerName`` of the
    virtual host, or ``['*']`` when there is none.  ``fingerprint_hash``
    is handed to get_pem_ssh_fingerprint() to compute the dictionary key.
    """
    # Load only the Httpd lens on the Apache site files instead of the
    # full default set of lenses.
    load_flags = augeas.Augeas.NO_LOAD + augeas.Augeas.NO_MODL_AUTOLOAD
    tree = augeas.Augeas(flags=load_flags)
    tree.set('/augeas/load/Httpd/lens', 'Httpd.lns')
    tree.set('/augeas/load/Httpd/incl[last() + 1]',
             '/etc/apache2/sites-available/*')
    tree.load()

    key_file_directives = ('GnuTLSKeyFile', 'SSLCertificateKeyFile')
    virtual_hosts = '/files/etc/apache2/sites-available//VirtualHost'

    found = {}
    for virtual_host in tree.match(virtual_hosts):
        entry = {'available_domains': ['*'], 'service': 'https'}
        for directive in tree.match(virtual_host + '/directive'):
            directive_name = tree.get(directive)
            if directive_name in key_file_directives:
                entry['key_file'] = tree.get(directive + '/arg')
            elif directive_name == 'ServerName':
                entry['available_domains'] = [tree.get(directive + '/arg')]

        # Virtual hosts without a private key have nothing to import.
        if 'key_file' not in entry:
            continue

        fingerprint = get_pem_ssh_fingerprint(
            entry['key_file'], fingerprint_hash)
        entry['ssh_fingerprint'] = fingerprint
        found[fingerprint] = entry

    return found


def get_monkeysphere_keys(key_id=None):
    """Return the keys imported into monkeysphere, indexed by SSH fingerprint.

    Runs ``monkeysphere-host show-keys`` (restricted to ``key_id`` when
    one is given) and parses its output.  Keys in the output are
    separated by blank lines.  Each value of the returned dictionary is
    itself a dictionary that may contain:

    - ``pub`` and ``date``: first two fields of the ``pub`` line
    - ``uids``: every user ID, without a leading ``[...]`` marker
    - ``service`` and ``imported_domains``: taken from user IDs of the
      form ``service://domain``
    - ``openpgp_fingerprint``
    - ``ssh_key_size``, ``ssh_fingerprint`` and ``ssh_key_type``
    - ``available_domains``: always an empty list here

    An empty dictionary is returned when the command fails, which is
    treated as "no keys available".
    """
    key_ids = [key_id] if key_id else []
    try:
        output = subprocess.check_output(
            ['monkeysphere-host', 'show-keys'] + key_ids,
            stderr=subprocess.DEVNULL)
    except subprocess.CalledProcessError:
        # no keys available
        return {}

    def new_key():
        # Build fresh lists for every key.  Copying one template
        # dictionary with dict.copy() would share the same list objects
        # between all keys and mix up their domains.
        return {'imported_domains': [], 'available_domains': []}

    openpgp_prefix = 'OpenPGP fingerprint:'
    ssh_prefix = 'ssh fingerprint:'

    # parse output
    keys = [new_key()]
    for line in output.decode().strip().split('\n'):
        current = keys[-1]
        # Prefixes are removed by slicing: str.lstrip() takes a set of
        # characters and could also eat the beginning of the value.
        if line.startswith('pub'):
            data = line[len('pub'):].split()
            current['pub'] = data[0]
            current['date'] = data[1]
        elif line.startswith('uid'):
            uid = line[len('uid'):].strip()
            if uid.startswith('['):
                uid = uid[uid.index(']') + 1:].strip()

            current.setdefault('uids', []).append(uid)
            # NOTE: the greedy (.*) group also captures a ':port'
            # suffix, so a port stays part of the imported domain.
            matches = re.match(r'([a-zA-Z]*)://(.*)(:\d*)?', uid)
            if matches:
                current['service'] = matches.group(1)
                current['imported_domains'].append(matches.group(2))
        elif line.startswith(openpgp_prefix):
            current['openpgp_fingerprint'] = \
                line[len(openpgp_prefix):].strip()
        elif line.startswith(ssh_prefix):
            data = line[len(ssh_prefix):].split()
            current['ssh_key_size'] = data[0]
            current['ssh_fingerprint'] = data[1]
            current['ssh_key_type'] = data[2].strip('()')
        elif not line.strip():
            keys.append(new_key())

    # Empty output or repeated blank lines leave entries without a
    # fingerprint behind; they do not describe a key, so drop them.
    return {key['ssh_fingerprint']: key
            for key in keys if 'ssh_fingerprint' in key}


def get_merged_keys(key_id=None):
    """Return monkeysphere keys merged with the keys found on the system.

    Keys known to monkeysphere get their ``available_domains`` and
    ``key_file`` from the matching system key.  System keys unknown to
    monkeysphere are added as they are, unless ``key_id`` was given, in
    which case only the monkeysphere keys are reported.
    """
    merged = get_monkeysphere_keys(key_id)

    # Monkeysphere used to use MD5 as the fingerprint hash and later
    # switched to SHA256, in which case 'SHA256:' is prepended to the
    # fingerprint.  Hoping that future changes keep such a prefix, take
    # the text before the first ':' of one key (assuming all the others
    # are the same) as the hash to use.  Default to SHA256 when there
    # is no key to look at.
    if merged:
        hash_name = next(iter(merged)).partition(':')[0]
    else:
        hash_name = 'SHA256'

    candidates = [*get_ssh_keys(hash_name).items(),
                  *get_https_keys(hash_name).items()]
    for fingerprint, system_key in candidates:
        if fingerprint in merged:
            merged[fingerprint].update({
                'available_domains': system_key['available_domains'],
                'key_file': system_key['key_file']})
        elif not key_id:
            merged[fingerprint] = system_key

    return merged


def subcommand_host_show_keys(arguments):
    """Print the merged keys as JSON under the top level entry 'keys'.

    The listing is limited by ``arguments.key_id`` when it is set.
    """
    payload = {'keys': get_merged_keys(arguments.key_id)}
    print(json.dumps(payload))


def subcommand_host_import_key(arguments, second_run=False):
    """Import a host key into monkeysphere and add its service names.

    The key is selected with ``arguments.ssh_fingerprint``.  When its
    available domains contain ``'*'``, ``arguments.domains`` is used
    instead.

    A key without an OpenPGP fingerprint is first imported with its
    first domain; the function then calls itself once more with
    ``second_run=True`` to add the remaining domains.  Otherwise every
    available domain that is not imported yet is added as a service
    name.

    Raises ``Exception`` when the fingerprint is unknown.
    """
    all_keys = get_merged_keys()
    fingerprint = arguments.ssh_fingerprint
    if fingerprint not in all_keys:
        raise Exception('Unknown SSH fingerprint')

    key = all_keys[fingerprint]
    if '*' in key['available_domains']:
        key['available_domains'] = arguments.domains

    # Keep monkeysphere-host from prompting interactively.
    environment = dict(os.environ, MONKEYSPHERE_PROMPT='false')

    needs_import = 'openpgp_fingerprint' not in key and not second_run
    if needs_import:
        subprocess.check_call(
            ['monkeysphere-host', 'import-key',
             key['key_file'],
             key['service'] + '://' + key['available_domains'][0]],
            env=environment)
        # Read the keys again and add the domains that are still missing.
        subcommand_host_import_key(arguments, second_run=True)
        return

    pending_domains = [domain for domain in key['available_domains']
                       if domain not in key['imported_domains']]
    for domain in pending_domains:
        subprocess.check_call(
            ['monkeysphere-host', 'add-servicename',
             key['service'] + '://' + domain, key['openpgp_fingerprint']],
            env=environment)


def subcommand_host_publish_key(arguments):
    """Publish host keys to the keyserver and print the command output.

    ``arguments.key_ids`` is appended to ``monkeysphere-host
    publish-keys``.  Raises ``Exception`` carrying the decoded standard
    output and standard error when the command exits with a non-zero
    status.
    """
    command = ['monkeysphere-host', 'publish-keys'] + arguments.key_ids

    # setting TMPDIR as workaround for Debian bug #656750
    environment = dict(
        os.environ,
        TMPDIR='/var/lib/monkeysphere/authentication/tmp/',
        MONKEYSPHERE_PROMPT='false')

    process = subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
        env=environment)
    stdout_text, stderr_text = (
        stream.decode() for stream in process.communicate())

    if process.returncode:
        raise Exception(stdout_text, stderr_text)

    print(stdout_text)


def subcommand_host_cancel_publish(arguments):
    """Kill a running publish process.

    ``arguments.pid`` is the PID of the process to terminate.  SIGTERM
    is sent only when the command line of that process, after removing
    a leading ``sudo`` and its options, starts with an executable named
    ``monkeysphere`` followed by ``host-publish-key``.  Nothing happens
    when the process does not exist (any more) or does not match.
    """
    try:
        process = psutil.Process(int(arguments.pid))

        # Perform tight checks on the process before killing for security.
        command_line = process.cmdline()
    except psutil.NoSuchProcess:
        # Process already completed
        return

    if not command_line:
        # Process already completed
        return

    # Remove the sudo prefix if present.  Stop when nothing is left so
    # that a command line made up only of 'sudo' and options does not
    # raise IndexError.
    while command_line and (command_line[0] == 'sudo' or
                            command_line[0].startswith('-')):
        command_line = command_line[1:]

    if len(command_line) >= 2 and \
       command_line[0].split('/')[-1] == 'monkeysphere' and \
       command_line[1] == 'host-publish-key':
        try:
            process.send_signal(signal.SIGTERM)
        except psutil.NoSuchProcess:
            # Process completed between the check and the signal
            pass


def main():
    """Parse the command line and run the handler of the sub command.

    The handler is the module level function named ``subcommand_``
    followed by the sub command with dashes replaced by underscores.
    """
    arguments = parse_arguments()

    handler_name = 'subcommand_' + arguments.subcommand.replace('-', '_')
    handler = globals()[handler_name]
    handler(arguments)


# Run the helper only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
