#!/usr/bin/env python3

import datetime
import logging
import pwd
from optparse import OptionParser
import os
import re
from subprocess import Popen, STDOUT
import sys

import configparser


# Default location of the config file; can be overridden with --config.
CONFIG_PATH = "/etc/endpoints.ini"
# Default directory for the per-endpoint log files; can be overridden with
# --log-dir.
LOG_DIR = "/var/log/update-remote-wn-client"


# Module-level logger used for all status and error messages in this script.
log = logging.getLogger(__name__)


class Error(Exception):
    """Raised when an Endpoint section of the config file fails validation."""


def check_section(cfg, section):
    """Verify a single ``Endpoint`` section of the config file.

    The following checks are made, in order:

    * the section name has the form ``Endpoint <name>`` where ``<name>``
      contains no whitespace, backslash, ``/``, or ``:``;
    * ``timeout``, if set and non-empty, is a non-negative integer;
    * the required options ``local_user`` and ``remote_host`` are present and
      non-empty;
    * ``local_user`` exists in the local password database.

    Args:
        cfg: A ``configparser`` parser that has already loaded the config.
        section: Name of the section to verify; must exist in ``cfg``.

    Returns:
        None when the section is valid.

    Raises:
        Error: If any of the checks above fails.
    """
    assert cfg.has_section(section), "%s missing - this should not happen" % section
    if not re.match(r"Endpoint\s+([^\s:\\/]+)\s*$", section):
        raise Error(
            "%s: Endpoint name can't contain whitespace, `\\`, `/`, or `:`" % section
        )

    def safe_get(option):
        """Return the option's value, or None if the section does not set it."""
        try:
            return cfg.get(section, option)
        except configparser.NoOptionError:
            return None

    local_user = safe_get("local_user")
    remote_host = safe_get("remote_host")
    timeout = safe_get("timeout")

    # An unset or empty timeout is allowed; anything else must be an integer.
    if timeout:
        try:
            timeout_seconds = int(timeout)
        except ValueError:
            raise Error("%s: Invalid timeout %s" % (section, timeout)) from None
        # A negative value would later be handed to timeout(1) as "-N", which
        # it would parse as a command-line option, so reject it up front.
        if timeout_seconds < 0:
            raise Error("%s: Invalid timeout %s" % (section, timeout))

    if not local_user:
        raise Error("%s: Missing required option local_user" % section)
    if not remote_host:
        raise Error("%s: Missing required option remote_host" % section)
    try:
        pwd.getpwnam(local_user)
    except KeyError:
        raise Error(
            "%s: local_user %s does not exist" % (section, local_user)
        ) from None


def call_updater(cfg, section, updater_script, log_dir, dry_run=False):
    """Call a single instance of the updater for a config section.

    The updater is run as ``local_user`` through ``sudo -n -H -u`` with ``/``
    as the working directory.  Unless the section's ``timeout`` is empty or
    zero, the command is wrapped in ``timeout --kill-after=60 <seconds>``; a
    section that does not set ``timeout`` at all gets 3600 seconds.  The
    command's stdout and stderr are appended to ``<log_dir>/<endpoint name>``.

    Args:
        cfg: A ``configparser`` parser holding the already-validated config.
        section: Name of the ``Endpoint <name>`` section to process.
        updater_script: Path of the updater executable to invoke.
        log_dir: Directory in which the per-endpoint log file is written.
        dry_run: If true, only print the command that would be run.

    Returns:
        True on success (a dry run always counts as success) and False on
        failure.
    """
    m = re.match(r"Endpoint\s+([^\s:\\/]+)\s*$", section)
    assert m, "Invalid chars in %s - this should have been caught" % section
    endpoint = m.group(1)

    def safe_get(option):
        """Return the option's value, or None if the section does not set it."""
        try:
            return cfg.get(section, option)
        except configparser.NoOptionError:
            return None

    local_user = safe_get("local_user")
    remote_host = safe_get("remote_host")
    upstream_url = safe_get("upstream_url")
    remote_user = safe_get("remote_user")
    remote_dir = safe_get("remote_dir")
    # Have default timeout be nonzero: only an explicitly empty or zero
    # setting disables the timeout wrapper.
    timeout = safe_get("timeout")
    if timeout is None:
        timeout = "3600"
    timeout = int(timeout) if timeout else 0

    assert local_user, "local_user not specified - this should have been caught"
    assert remote_host, "remote_host not specified - this should have been caught"
    assert pwd.getpwnam(local_user), "nonexistant user - this should have been caught"

    cmd = [updater_script, remote_host]
    if upstream_url:
        cmd.append("--upstream-url=%s" % upstream_url)
    if remote_user:
        cmd.append("--remote-user=%s" % remote_user)
    if remote_dir:
        cmd.append("--remote-dir=%s" % remote_dir)

    timeout_cmd = []
    if timeout:
        timeout_cmd = ["timeout", "--kill-after=60", str(timeout)]

    if dry_run:
        if timeout_cmd:
            print("Using %r" % timeout_cmd)
        print("Would run %r as %s" % (cmd, local_user))
        # Nothing was run, so nothing failed.  Returning None here would be
        # read as a failure by callers that test the result for truth.
        return True

    log_file_path = os.path.join(log_dir, endpoint)
    with open(log_file_path, "ab") as log_file:
        log_file.write(b"-" * 79 + b"\n")
        now = datetime.datetime.now()
        log_file.write(
            b"Started at %s\nUser: %s\nCommand: %r\n\n"
            % (now.strftime("%F %T").encode('latin-1'),
               local_user.encode('latin-1'),
               cmd)
        )
        # Flush our header before the child starts writing to the same file
        # descriptor, so the header ends up ahead of the command's output.
        log_file.flush()
        proc = Popen(
            ["sudo", "-n", "-H", "-u", local_user] + timeout_cmd + cmd,
            stdout=log_file,
            stderr=STDOUT,
            cwd="/",
        )
        returncode = proc.wait()

    if returncode == 0:
        log.info("Endpoint %s ok.", endpoint)
        return True

    # 124 is the exit status timeout(1) uses to report that the limit was hit.
    if returncode == 124:
        log.error(
            "Endpoint %s TIMED OUT.  See %s for details.", endpoint, log_file_path
        )
    else:
        log.error(
            "Endpoint %s FAILED with error %d.  See %s for details.",
            endpoint,
            returncode,
            log_file_path,
        )
    return False


def which(executable):
    """Locate ``executable`` in the directories listed in ``PATH``.

    Args:
        executable: File name of the program to look for.

    Returns:
        The absolute path of the first regular file named ``executable`` that
        is executable by the current user, or None if there is none.  When
        ``PATH`` is not set in the environment, ``os.defpath`` is searched.
    """
    search_path = os.environ.get("PATH", os.defpath)
    for directory in search_path.split(os.pathsep):
        candidate = os.path.join(directory, executable)
        # isfile() rather than exists(): a directory normally carries the
        # execute bit too and must not be mistaken for the program.
        if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
            return os.path.abspath(candidate)
    return None


def main():
    """Validate the config file and run the updater for every endpoint.

    Command-line options select the config file, the log directory, and
    dry-run mode.  Every section whose name starts with ``Endpoint`` is
    validated before any updater is started, so one bad section aborts the
    whole run.

    Returns:
        The process exit status: 1 if the config file cannot be read or
        parsed, contains no Endpoint sections, or fails validation; 1 if the
        updater script is not found on PATH (outside dry-run mode); 1 if
        ``call_updater`` reports a false result for any endpoint; otherwise 0.
    """
    parser = OptionParser()
    parser.add_option(
        "--config",
        default=CONFIG_PATH,
        help="Location of config file. [default: %default]",
    )
    parser.add_option(
        "--log-dir",
        default=LOG_DIR,
        help="Location of log directory. [default: %default]",
    )
    parser.add_option(
        "-n",
        "--dry-run",
        action="store_true",
        help="Don't change anything, just print the commands that would be run",
    )
    opts, _ = parser.parse_args()

    cfg = configparser.ConfigParser()
    try:
        files_read = cfg.read(opts.config)
    except configparser.Error as e:
        log.error("Could not parse config file %s: %s", opts.config, e)
        return 1
    # read() silently skips files it cannot open and returns only the names
    # it managed to parse, so an empty result means the config was not loaded.
    if not files_read:
        log.error("Could not read config file %s", opts.config)
        return 1

    sections_list = [name for name in cfg.sections() if name.startswith("Endpoint")]
    if not sections_list:
        log.error("No Endpoint sections found")
        return 1

    try:
        for section in sections_list:
            check_section(cfg, section)
    except Error as e:
        log.error(e)
        return 1

    updater_script = which("update-remote-wn-client")
    if not updater_script:
        msg = "update-remote-wn-client not found in PATH"
        if opts.dry_run:
            # A dry run only prints commands, so a placeholder name is enough.
            log.warning(msg)
            updater_script = "update-remote-wn-client"
        else:
            log.error(msg)
            return 1

    # Keep going after a failed endpoint so the remaining ones still get
    # updated; the failure is remembered in the exit status.
    ret = 0
    for section in sections_list:
        ok = call_updater(cfg, section, updater_script, opts.log_dir, opts.dry_run)
        if not ok:
            ret = 1

    return ret


if __name__ == "__main__":
    # Show INFO and above, each message prefixed with "*** ", then exit with
    # the status returned by main().
    logging.basicConfig(level=logging.INFO, format="*** %(message)s")
    raise SystemExit(main())
