#! /usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import re
import shutil
import signal
import sqlite3
import subprocess as sp
import sys

from datetime import datetime

from getpass import getpass

from nacl.encoding import Base64Encoder

from globaleaks.db import get_db_file
from globaleaks.orm import make_db_uri, get_engine
from globaleaks.rest.requests import AdminNotificationDesc, AdminNodeDesc
from globaleaks.settings import Settings
from globaleaks.utils.crypto import GCE, generateRandomPassword
from globaleaks.utils.utility import datetime_now


def is_number(value):
    """Return True if *value* can be converted to an integer with ``int()``.

    Despite the name only integers are accepted: a string such as ``"1.5"``
    is rejected because ``int()`` cannot parse it.

    :param value: the value to check (typically a command line string).
    :return: True when ``int(value)`` succeeds, False otherwise.
    """
    try:
        int(value)
    except (TypeError, ValueError):
        # TypeError covers inputs int() does not accept at all (e.g. None).
        return False

    return True


def check_file(f):
    """Raise RuntimeError unless *f* is an existing and readable regular file."""
    readable_file = os.path.isfile(f) and os.access(f, os.R_OK)
    if readable_file:
        return

    raise RuntimeError("Missing or inaccessible file: {}".format(f))


def check_dir(d):
    """Raise RuntimeError unless *d* is an existing directory."""
    if os.path.isdir(d):
        return

    raise RuntimeError("Missing or inaccessible dir: {}".format(d))


def check_db(d):
    """Locate the database file inside the working directory *d*.

    :param d: path of the globaleaks working directory.
    :return: the path of the database file, or None when get_db_file()
        reports a non positive database version.
    :raises RuntimeError: if *d* is not a directory or the database file is
        missing or not readable.
    """
    # Rely on the directory received as argument rather than on the
    # module-level ``args``, which is only defined when the file is executed
    # as a script.
    check_dir(d)

    db_version, db_path = get_db_file(d)

    if db_version <= 0:
        return None

    check_file(db_path)

    return db_path


def default_backup_path():
    """Return the default backup archive path: a dated file in the current directory."""
    stamp = datetime.now().strftime("%y_%m_%d")
    archive_name = "globaleaks_backup_{}.tar.gz".format(stamp)
    current_dir = os.getcwd()
    return os.path.join(current_dir, archive_name)


def get_gl_pid():
    """Return the pid stored in the globaleaks pidfile.

    :return: the integer read from ``Settings.pidfile_path``, or 0 when the
        file cannot be read or does not contain an integer.
    """
    try:
        with open(Settings.pidfile_path, 'r') as fd:
            return int(fd.read())
    except (OSError, ValueError):
        # OSError: missing or unreadable pidfile; ValueError: the content is
        # not an integer. Anything else (e.g. KeyboardInterrupt) propagates.
        return 0


def send_gl_signal(sig):
    """Send the signal *sig* to the process whose pid is stored in the pidfile.

    :param sig: the signal number handed to ``os.kill``.
    :return: True when ``os.kill`` succeeds; False when no pid is available
        or the signal cannot be delivered.
    """
    pid = get_gl_pid()
    if not pid:
        return False

    try:
        os.kill(pid, sig)
    except (OSError, OverflowError):
        # OSError: no such process or missing permission; OverflowError: the
        # pidfile holds a number that is not a valid pid.
        return False

    return True


def is_gl_running():
    """Return the outcome of sending signal 0 to the globaleaks process."""
    probe_signal = 0
    return send_gl_signal(probe_signal)


def reset_gl_cache():
    """Send SIGUSR1 to the globaleaks process and return the outcome.

    NOTE(review): presumably the daemon reacts to SIGUSR1 by refreshing its
    cache; the signal handler is not visible from this file.
    """
    outcome = send_gl_signal(signal.SIGUSR1)
    return outcome


def backup(args):
    """Create a gzip compressed tar archive of the working directory.

    The globaleaks service is stopped while the archive is created when it
    is detected as running, and started again afterwards even if the archive
    creation fails.

    :param args: parsed arguments providing ``workdir`` and ``backuppath``.
    :raises RuntimeError: if the working directory does not exist.
    :raises subprocess.CalledProcessError: if systemctl or tar fail.
    """
    workdir = args.workdir
    check_dir(workdir)

    must_stop = is_gl_running()

    if must_stop:
        sp.check_call(["systemctl", "stop", "globaleaks"])

    try:
        print("Creating an archive backup of the globaleaks setup. . .")

        # The command is executed without a shell, so the exclude pattern must
        # not be quoted: quotes would be taken as part of the pattern and the
        # 'backups' directory would end up inside the archive.
        sp.check_call(["tar", "-zcf", args.backuppath, "--exclude=backups", "-C", workdir, '.'])
    finally:
        # Never leave the service down because the archive creation failed.
        if must_stop:
            sp.check_call(["systemctl", "start", "globaleaks"])

    print("Success: The archived backup was created at:", args.backuppath)


def restore(args):
    """Replace the content of the working directory with a backup archive.

    After an interactive confirmation the working directory is deleted,
    recreated and populated by extracting the archive. The globaleaks service
    is stopped during the operation when it is detected as running.

    :param args: parsed arguments providing ``workdir`` and ``backuppath``.
    :raises RuntimeError: if the working directory or the archive are missing,
        or if the archive is stored inside the working directory.
    :raises subprocess.CalledProcessError: if systemctl or tar fail.
    """
    check_dir(args.workdir)

    check_file(args.backuppath)

    # An archive located inside the working directory would be removed by the
    # deletion below before it could be extracted.
    archive_real = os.path.realpath(args.backuppath)
    workdir_real = os.path.realpath(args.workdir)
    if os.path.commonpath([archive_real, workdir_real]) == workdir_real:
        raise RuntimeError("The backup {} is located inside {} and would be deleted "
                           "before being extracted".format(args.backuppath, args.workdir))

    print("\n", "-"*72)
    print("WARNING this command will DELETE everything currently in {}".format(args.workdir))
    print("-"*72)
    ans = input("Are you sure that you want to continue? [y/n] ")
    if ans != "y":
        sys.exit(0)
    print("-"*72)

    # Make sure the archive can be read before destroying the current data.
    sp.check_call(["tar", "-tf", args.backuppath], stdout=sp.DEVNULL)

    must_stop = is_gl_running()
    if must_stop:
        sp.check_call(["systemctl", "stop", "globaleaks"])

    print("Deleting {} . . .".format(args.workdir))

    shutil.rmtree(args.workdir)
    os.makedirs(args.workdir)

    print("Extracting the archive {}".format(args.backuppath))
    sp.check_call(["tar", "-xf", args.backuppath, "-C", args.workdir])

    if must_stop:
        sp.check_call(["systemctl", "start", "globaleaks"])

    print("Success! globaleaks has been restored from a backup")


def reset_password(args):
    """Reset the password of the user ``args.username`` of tenant ``args.tid``.

    The operator is prompted for the credentials of a user of the root tenant
    (tid 1); the supplied password is verified against the stored hash and is
    also used to decrypt the operator private key and, when present, the
    escrow private key.

    The target user gets a new password produced by
    ``generateRandomPassword(16)``, which is printed at the end. The user
    private key is either re-encrypted under the new password (escrow backup
    available), cleared (user without a public key) or the reset is refused.

    On success the ``user`` row is updated, the password change is flagged as
    needed and a ``change_password`` entry is added to ``auditlog``.

    The process exits with status 1 when the operator or the target user do
    not exist, the operator password is wrong, or the user key cannot be
    recovered.

    :param args: parsed arguments providing ``workdir``, ``tid`` and
        ``username``.
    """
    db_path = check_db(args.workdir)

    # Credentials of the operator performing the reset
    admin_username = input("Username: ")
    admin_password = getpass()

    conn = sqlite3.connect(db_path)
    c = conn.cursor()

    # New credentials for the target user: password, salt and the derived
    # encryption key and hash
    password = generateRandomPassword(16)
    user_salt = GCE.generate_salt()
    user_enc_key, user_hash = GCE.calculate_key_and_hash(password, user_salt)

    # The operator authenticates as an administrator of the root tenant (tid 1):
    # recovery relies on the root tenant escrow key, which is the only escrow key
    # reachable from the CLI and the one the user account keys are backed up to
    # (crypto_escrow_bkp1_key).
    QUERY = "SELECT id, salt, hash, crypto_prv_key, crypto_escrow_prv_key FROM user WHERE tid=? AND username=?"
    c.execute(QUERY, (1, admin_username))

    admin_user = c.fetchone()
    if admin_user is None:
        print("Failed! The specified admin user '{}' does not exist".format(admin_username))
        sys.exit(1)

    admin_id, admin_salt, admin_hash, admin_crypto_prv_key, admin_crypto_escrow_prv_key = admin_user[0], admin_user[1], admin_user[2], admin_user[3], admin_user[4]

    # Verify the operator password against the stored hash
    _, check_hash = GCE.calculate_key_and_hash(admin_password, admin_salt)
    if not GCE.check_equality(admin_hash, check_hash):
        print("Failed! Invalid password")
        sys.exit(1)

    admin_enc_key = GCE.derive_key(admin_password.encode(), admin_salt)

    # admin_ek stays None when the operator has no private key or no escrow
    # private key
    admin_ek = None
    if admin_crypto_prv_key:
        try:
            admin_cc = GCE.symmetric_decrypt(admin_enc_key, Base64Encoder.decode(admin_crypto_prv_key))
        except:
            print("Failed! Invalid password")
            sys.exit(1)

        if admin_crypto_escrow_prv_key:
            admin_ek = GCE.asymmetric_decrypt(admin_cc, Base64Encoder.decode(admin_crypto_escrow_prv_key))

    # Look up the user whose password has to be reset
    QUERY = "SELECT id, crypto_pub_key, crypto_escrow_bkp1_key FROM user WHERE tid=? AND username=?;"
    c.execute(QUERY, (args.tid, args.username))

    user = c.fetchone()

    if user is None:
        print("Failed! The user '{}' does not exist".format(args.username))
        sys.exit(1)

    user_id, user_crypto_pub_key, user_crypto_escrow_bkp1_key = user[0], user[1], user[2]

    if user_crypto_escrow_bkp1_key and admin_ek is not None:
        # The account key is recoverable from the root tenant escrow backup:
        # re-wrap it under the new password.
        user_cc = GCE.asymmetric_decrypt(admin_ek, Base64Encoder.decode(user_crypto_escrow_bkp1_key))
        user_crypto_prv_key = Base64Encoder.encode(GCE.symmetric_encrypt(user_enc_key, user_cc))
    elif not user_crypto_pub_key:
        # The user has no encryption enabled: a plain password reset is safe.
        user_crypto_prv_key = ''
    else:
        # The user has encryption enabled but the operator holds no escrow key
        # able to recover the account key. Overwriting crypto_prv_key here would
        # irreversibly destroy the key and every report the user can decrypt, so
        # refuse the reset instead of silently causing data loss.
        print("Failed! The user '{}' has encryption enabled and the account key "
              "cannot be recovered without access to the root tenant escrow key; "
              "the password has not been reset to avoid permanent data loss".format(args.username))
        sys.exit(1)

    # Store the new credentials and force a password change
    QUERY = "UPDATE user SET salt=?, hash=?, crypto_prv_key=?, password_change_date=?, password_change_needed=? WHERE id=?;"
    c.execute(QUERY, (user_salt, user_hash, user_crypto_prv_key, datetime_now(), True, user_id))

    # Record in the audit log which operator changed the password of which user
    QUERY = "INSERT INTO auditlog(tid, date, type, user_id, object_id) VALUES(?,?,?,?,?);"
    c.execute(QUERY, (args.tid, datetime_now(), 'change_password', admin_id, user_id))

    conn.commit()
    conn.close()

    print(("=========================================\n"
           "||      Password reset completed       ||\n"
           "=========================================\n"
           " Username: {}\n"
           " Password: {}\n"
           "=========================================\n"
         ).format(args.username, password))


def disable_2fa(args):
    """Disable the two factor authentication of a user.

    The operator is prompted for the credentials of a user of tenant
    ``args.tid``; once the password is verified the ``two_factor_secret`` of
    ``args.username`` is cleared and a ``disable_2fa`` entry is added to
    ``auditlog``.

    The process exits with status 1 when the operator or the target user do
    not exist or when the operator password is wrong.

    :param args: parsed arguments providing ``workdir``, ``tid`` and
        ``username``.
    """
    db_path = check_db(args.workdir)

    conn = sqlite3.connect(db_path)
    try:
        c = conn.cursor()

        admin_username = input("Username: ")
        admin_password = getpass()

        QUERY = "SELECT id, salt, hash FROM user WHERE tid=? AND username=?"
        c.execute(QUERY, (args.tid, admin_username))

        admin_user = c.fetchone()
        if admin_user is None:
            print("Failed! The specified admin user '{}' does not exist".format(admin_username))
            sys.exit(1)

        admin_id, admin_salt, admin_hash = admin_user[0], admin_user[1], admin_user[2]

        # Verify the operator password against the stored hash
        _, check_hash = GCE.calculate_key_and_hash(admin_password, admin_salt)
        if not GCE.check_equality(admin_hash, check_hash):
            print("Failed! Invalid password")
            sys.exit(1)

        QUERY = "SELECT id FROM user WHERE tid=? AND username=?;"
        c.execute(QUERY, (args.tid, args.username))

        user = c.fetchone()
        if user is None:
            print("Failed! The user '{}' does not exist".format(args.username))
            sys.exit(1)

        user_id = user[0]

        QUERY = "UPDATE user SET two_factor_secret=? WHERE id=?;"
        c.execute(QUERY, ('', user_id))

        QUERY = "INSERT INTO auditlog(tid, date, type, user_id, object_id) VALUES(?,?,?,?,?);"
        c.execute(QUERY, (args.tid, datetime_now(), 'disable_2fa', admin_id, user_id))

        conn.commit()
    finally:
        # Also reached on sys.exit(): the connection is never left open.
        conn.close()

    print(("=========================================\n"
           "||         2fa disabled                ||\n"
           "=========================================\n"
           " Username: {}\n"
           "=========================================\n"
         ).format(args.username))


def get_var(args):
    """Print the value of a configuration variable of a tenant.

    The value is read from the ``config`` table and decoded from JSON before
    being printed. The process exits with status 1 when no row matches.

    :param args: parsed arguments providing ``workdir``, ``tid`` and
        ``varname``.
    """
    db_path = check_db(args.workdir)

    QUERY = "SELECT value FROM config WHERE var_name=? AND tid=?;"

    conn = sqlite3.connect(db_path)
    try:
        c = conn.cursor()
        c.execute(QUERY, (args.varname, args.tid))
        ret = c.fetchone()
    finally:
        # Close the connection on the failure path as well
        conn.close()

    if ret is None:
        print("Failed to read value of var '{}'.".format(args.varname))
        sys.exit(1)

    print(json.loads(str(ret[0])))


def set_var(args, silent=False):
    """Set the value of an existing configuration variable of a tenant.

    ``args.value`` is converted in place: ``true``/``false`` (any case) become
    booleans, integer literals become ints and anything else is kept as a
    string. The converted value is stored JSON encoded in the ``config``
    table.

    The process exits with status 1 when no ``config`` row matches the
    tenant and variable name, since nothing would be stored.

    :param args: parsed arguments providing ``workdir``, ``tid``, ``varname``
        and ``value``.
    :param silent: when True the success message is not printed.
    """
    db_path = check_db(args.workdir)

    lowered = args.value.lower()
    if lowered == 'true':
        args.value = True
    elif lowered == 'false':
        args.value = False
    elif is_number(args.value):
        args.value = int(args.value)

    value = json.dumps(args.value)

    QUERY = "UPDATE config SET value=? WHERE tid=? AND var_name=?;"

    conn = sqlite3.connect(db_path)
    try:
        c = conn.cursor()
        c.execute(QUERY, (value, args.tid, args.varname))
        updated = c.rowcount
        conn.commit()
    finally:
        conn.close()

    # An UPDATE that matches no row is not an error for sqlite: detect it to
    # avoid reporting a success when the variable does not exist.
    if updated < 1:
        print("Failed to set value of var '{}'.".format(args.varname))
        sys.exit(1)

    reset_gl_cache()

    if not silent:
        print("Success! {} set to '{}'".format(args.varname, args.value))


def set_user_protection(args, protect):
    """Add or remove a user from the ``protected_users`` list of a tenant.

    The list is stored JSON encoded in the ``config`` table. The process
    exits with status 1 when the user does not exist or when the tenant has
    no ``protected_users`` row to update.

    :param args: parsed arguments providing ``workdir``, ``tid`` and
        ``username``.
    :param protect: True to protect the user, False to remove the protection.
    """
    db_path = check_db(args.workdir)

    conn = sqlite3.connect(db_path)
    try:
        c = conn.cursor()

        c.execute("SELECT id FROM user WHERE tid=? AND username=?;", (args.tid, args.username))
        user = c.fetchone()
        if user is None:
            print("Failed! The user '{}' does not exist".format(args.username))
            sys.exit(1)

        user_id = user[0]

        # protected_users stores the list of protected user ids; ids are used in
        # place of usernames so that the protection survives a username change.
        c.execute("SELECT value FROM config WHERE tid=? AND var_name='protected_users';", (args.tid,))
        ret = c.fetchone()
        protected = json.loads(ret[0]) if ret is not None and ret[0] else []

        if protect:
            if user_id not in protected:
                protected.append(user_id)
        else:
            protected = [u for u in protected if u != user_id]

        c.execute("UPDATE config SET value=? WHERE tid=? AND var_name='protected_users';",
                  (json.dumps(protected), args.tid))

        # An UPDATE that matches no row is not an error for sqlite: without
        # this check a missing config row would be reported as a success
        # while nothing is stored.
        if c.rowcount < 1:
            print("Failed! The 'protected_users' setting does not exist for tenant {}".format(args.tid))
            sys.exit(1)

        conn.commit()
    finally:
        # Also reached on sys.exit(): the connection is never left open.
        conn.close()

    reset_gl_cache()

    print("Success! User '{}' has been {}.".format(args.username,
          "protected" if protect else "unprotected"))


def protect_user(args):
    """Command handler: mark the user selected by *args* as protected."""
    set_user_protection(args, protect=True)


def unprotect_user(args):
    """Command handler: remove the protection from the user selected by *args*."""
    set_user_protection(args, protect=False)


def add_workingdir_path_arg(parser):
    """Register the ``-w/--workdir`` option on *parser*.

    The option defaults to ``Settings.working_path``.
    """
    workdir_help = "the directory that hosts the globaleaks data storage"
    parser.add_argument("-w", "--workdir", default=Settings.working_path, help=workdir_help)

Settings.eval_paths()

# Command line interface: one sub-command per administrative operation; each
# sub-parser stores its handler in ``func``.
parser = argparse.ArgumentParser(prog="gl-admin",
                 description="GlobaLeaks backend administrator interface")

subp = parser.add_subparsers(title="commands")

# backup [backuppath]
bck_p = subp.add_parser("backup", help="create a backup of the setup")
add_workingdir_path_arg(bck_p)
bck_p.add_argument("backuppath", nargs="?", help="the path and name of the backup",
                   default=default_backup_path())
bck_p.set_defaults(func=backup)

# restore [backuppath]
res_p = subp.add_parser("restore", help="restore a backup of the setup")
add_workingdir_path_arg(res_p)
res_p.add_argument("backuppath", nargs="?", help="the path and name of the backup",
                   default=default_backup_path())
res_p.set_defaults(func=restore)

# reset-password [--tid N] username
pw_p = subp.add_parser("reset-password", help="reset the password of a user")
add_workingdir_path_arg(pw_p)
pw_p.add_argument("--tid", help="the tenant id", default=1, type=int)
pw_p.add_argument("username")
pw_p.set_defaults(func=reset_password)

# disable-2fa [--tid N] username
pw_p = subp.add_parser("disable-2fa", help="disable the two factor authentication of a user")
add_workingdir_path_arg(pw_p)
pw_p.add_argument("--tid", help="the tenant id", default=1, type=int)
pw_p.add_argument("username")
pw_p.set_defaults(func=disable_2fa)

# getvar [--tid N] [varname]
rv_p = subp.add_parser("getvar", help="get database config variable")
add_workingdir_path_arg(rv_p)
rv_p.add_argument("--tid", help="the tenant id", default=1, type=int)
# nargs="?" makes the positional optional; without it argparse requires the
# argument and the 'version' default can never be used.
rv_p.add_argument("varname", nargs="?", help="the name of the config var", default='version', type=str)
rv_p.set_defaults(func=get_var)

# setvar [--tid N] varname value
sv_p = subp.add_parser("setvar", help="set database config variable")
add_workingdir_path_arg(sv_p)
sv_p.add_argument("--tid", help="the tenant id", default=1, type=int)
sv_p.add_argument("varname", help="the name of the config var", type=str)
sv_p.add_argument("value", help="value which must be of the correct type Bool(true|false), Int(0-9^9), String(everything else)")
sv_p.set_defaults(func=set_var)

# protect-user [--tid N] username
pr_p = subp.add_parser("protect-user", help="protect a user from deletion and password reset")
add_workingdir_path_arg(pr_p)
pr_p.add_argument("--tid", help="the tenant id", default=1, type=int)
pr_p.add_argument("username")
pr_p.set_defaults(func=protect_user)

# unprotect-user [--tid N] username
upr_p = subp.add_parser("unprotect-user", help="remove the protection from a user")
add_workingdir_path_arg(upr_p)
upr_p.add_argument("--tid", help="the tenant id", default=1, type=int)
upr_p.add_argument("username")
upr_p.set_defaults(func=unprotect_user)

if __name__ == '__main__':
    # Script entry point: run the handler of the selected sub-command, or show
    # the help when no sub-command is given.
    try:
        args = parser.parse_args()
        if hasattr(args, 'func'):
            args.func(args)
        else:
            parser.print_help()

    except Exception as exc:
        print("ERROR: {}".format(exc))
        # Report the failure through the exit status as well, so that shells
        # and scripts calling the tool can detect it.
        sys.exit(1)
