#!/usr/bin/python3

import argparse
import sys
import curses
import errno
import json
import signal
import time

from collections import OrderedDict
from datetime import datetime
from enum import Enum, unique

import rados


class FSTopException(Exception):
    """Error raised by cephfs-top for failures that are reported to the user.

    The message is stored in ``error_msg`` (returned by get_error_msg()) and
    is also handed to the base ``Exception`` so that ``str(exc)`` yields the
    same text regardless of whether ``msg`` was passed positionally or by
    keyword.
    """

    def __init__(self, msg=''):
        # initialise the base class so str()/args carry the message too
        super().__init__(msg)
        self.error_msg = msg

    def get_error_msg(self):
        """Return the message this exception was created with."""
        return self.error_msg


@unique
class MetricType(Enum):
    """Kind of a displayed metric; selects unit suffix and value formatting.

    FSTop.mtype() maps a member to the unit suffix shown in the column
    heading and FSTop.refresh_client() uses it to pick the formatter.
    """
    # no unit suffix; refresh_client() has no formatter for this type
    METRIC_TYPE_NONE = 0
    # heading suffix "(%)"; value formatted with calc_perc()
    METRIC_TYPE_PERCENTAGE = 1
    # heading suffix "(s)"; value formatted with calc_lat()
    METRIC_TYPE_LATENCY = 2


# program name shown in the first header line
FS_TOP_PROG_STR = 'cephfs-top'

# version match b/w fstop and stats emitted by mgr/stats
FS_TOP_SUPPORTED_VER = 1

# padding placed between columns of the top line (two spaces per pad unit)
ITEMS_PAD_LEN = 1
ITEMS_PAD = "  " * ITEMS_PAD_LEN

# metadata provided by mgr/stats
FS_TOP_MAIN_WINDOW_COL_CLIENT_ID = "CLIENT_ID"
FS_TOP_MAIN_WINDOW_COL_MNT_ROOT = "MOUNT_ROOT"
FS_TOP_MAIN_WINDOW_COL_MNTPT_HOST_ADDR = "MOUNT_POINT@HOST/ADDR"

# columns shown before (START) and after (END) the metric columns
MAIN_WINDOW_TOP_LINE_ITEMS_START = [ITEMS_PAD,
                                    FS_TOP_MAIN_WINDOW_COL_CLIENT_ID,
                                    FS_TOP_MAIN_WINDOW_COL_MNT_ROOT]
MAIN_WINDOW_TOP_LINE_ITEMS_END = [FS_TOP_MAIN_WINDOW_COL_MNTPT_HOST_ADDR]

# adjust this map according to stats version and maintain order
# as emitted by mgr/stats
MAIN_WINDOW_TOP_LINE_METRICS = OrderedDict([
    ("CAP_HIT", MetricType.METRIC_TYPE_PERCENTAGE),
    ("READ_LATENCY", MetricType.METRIC_TYPE_LATENCY),
    ("WRITE_LATENCY", MetricType.METRIC_TYPE_LATENCY),
    ("METADATA_LATENCY", MetricType.METRIC_TYPE_LATENCY),
    ("DENTRY_LEASE", MetricType.METRIC_TYPE_PERCENTAGE),
])
# metric names in display order; indexed by counter position in refresh_client()
MGR_STATS_COUNTERS = list(MAIN_WINDOW_TOP_LINE_METRICS.keys())

# format strings for the two header lines drawn by refresh_header()
FS_TOP_VERSION_HEADER_FMT = '{prog_name} - {now}'
FS_TOP_CLIENT_HEADER_FMT = 'Client(s): {num_clients} - {num_mounts} FUSE, '\
    '{num_kclients} kclient, {num_libs} libcephfs'

# keys looked up in the JSON returned by the 'fs perf stats' query:
# the per-client metadata map and the entries read from each client's metadata
CLIENT_METADATA_KEY = "client_metadata"
CLIENT_METADATA_MOUNT_POINT_KEY = "mount_point"
CLIENT_METADATA_MOUNT_ROOT_KEY = "root"
CLIENT_METADATA_IP_KEY = "IP"
CLIENT_METADATA_HOSTNAME_KEY = "hostname"

# per-client metric values and the list of counter names they correspond to
GLOBAL_METRICS_KEY = "global_metrics"
GLOBAL_COUNTERS_KEY = "global_counters"


def calc_perc(c):
    """Return c[0] as a percentage of c[0] + c[1], rounded to two decimals.

    A pair of zeros yields 0.0 instead of dividing by zero.
    """
    first = c[0]
    second = c[1]
    if first == 0 and second == 0:
        return 0.0
    ratio = first / (first + second)
    return round(ratio * 100, 2)


def calc_lat(c):
    """Return c[0] + c[1] / 10**9, rounded to two decimals.

    NOTE(review): presumably c is a (seconds, nanoseconds) pair as emitted
    by mgr/stats -- confirm against the stats module.
    """
    divisor = 1000000000
    whole = c[0]
    fraction = c[1] / divisor
    return round(whole + fraction, 2)


def wrap(s, sl):
    """return a '+' suffixed wrapped string

    `s` is returned unchanged when it fits into `sl` characters (a string of
    exactly `sl` characters fits and is not truncated). Otherwise the result
    is the first `sl - 1` characters of `s` followed by '+', i.e. exactly
    `sl` characters long. A non-positive `sl` leaves no room at all and
    yields an empty string.
    """
    if sl <= 0:
        return ''
    if len(s) <= sl:
        return s
    return f'{s[0:sl-1]}+'


class FSTop(object):
    """Curses based viewer for the per-client stats returned by 'fs perf stats'.

    Connects to a cluster through rados, queries the mgr once per second and
    draws three curses windows: a two line header, the column heading (top
    line) and the main window with one row per client.
    """

    def __init__(self, args):
        """Store connection settings from the parsed command line.

        `args` must provide the attributes `id`, `cluster` and `conffile`.
        No connection is made here; see init().
        """
        self.rados = None
        self.stop = False
        self.stdscr = None  # curses instance
        # curses windows, created by setup_curses()
        self.header = None
        self.topl = None
        self.mainw = None
        self.client_name = args.id
        self.cluster_name = args.cluster
        self.conffile = args.conffile

    def handle_signal(self, signum, _):
        """Signal handler: ask the display loop to terminate."""
        self.stop = True

    def init(self):
        """Connect to the cluster, verify stats support, install signal handlers.

        Raises FSTopException when the connection fails or when the 'stats'
        mgr module is not enabled.
        """
        try:
            if self.conffile:
                r_rados = rados.Rados(rados_id=self.client_name, clustername=self.cluster_name,
                                      conffile=self.conffile)
            else:
                r_rados = rados.Rados(rados_id=self.client_name, clustername=self.cluster_name)
            r_rados.conf_read_file()
            r_rados.connect()
            self.rados = r_rados
        except rados.Error as e:
            if e.errno == errno.ENOENT:
                raise FSTopException(f'cluster {self.cluster_name} does not exist') from e
            raise FSTopException(f'error connecting to cluster: {e}') from e
        self.verify_perf_stats_support()
        signal.signal(signal.SIGTERM, self.handle_signal)
        signal.signal(signal.SIGINT, self.handle_signal)

    def fini(self):
        """Shut down the cluster connection, if one was established."""
        if self.rados:
            self.rados.shutdown()
            self.rados = None

    def selftest(self):
        """Run one stats query and check its version; raise FSTopException on mismatch."""
        stats_json = self.perf_stats_query()
        if stats_json['version'] != FS_TOP_SUPPORTED_VER:
            raise FSTopException('perf stats version mismatch!')

    def setup_curses(self):
        """Create the curses windows and run the display loop until stopped."""
        self.stdscr = curses.initscr()
        try:
            # coordinates for windowing -- (height, width, y, x)
            # NOTE: requires initscr() call before accessing COLS, LINES.
            header_coord = (2, curses.COLS - 1, 0, 0)
            topline_coord = (1, curses.COLS - 1, 3, 0)
            main_coord = (curses.LINES - 4, curses.COLS - 1, 4, 0)

            self.header = curses.newwin(*header_coord)
            self.topl = curses.newwin(*topline_coord)
            self.mainw = curses.newwin(*main_coord)
        except curses.error:
            # window creation failed after initscr(); leave curses mode
            # before the error is reported on the terminal
            curses.endwin()
            raise
        curses.wrapper(self.display)

    def verify_perf_stats_support(self):
        """Raise FSTopException unless the 'stats' mgr module is enabled."""
        mon_cmd = {'prefix': 'mgr module ls', 'format': 'json'}
        try:
            ret, buf, out = self.rados.mon_command(json.dumps(mon_cmd), b'')
        except Exception as e:
            raise FSTopException(f'error checking \'stats\' module: {e}') from e
        if ret != 0:
            raise FSTopException(f'error checking \'stats\' module: {out}')
        if 'stats' not in json.loads(buf.decode('utf-8'))['enabled_modules']:
            raise FSTopException('\'stats\' module not enabled. Use \'ceph mgr module '
                                 'enable stats\' to enable')

    def perf_stats_query(self):
        """Run 'fs perf stats' on the mgr and return the decoded JSON reply.

        Raises FSTopException when the command fails or returns non-zero.
        """
        mgr_cmd = {'prefix': 'fs perf stats', 'format': 'json'}
        try:
            ret, buf, out = self.rados.mgr_command(json.dumps(mgr_cmd), b'')
        except Exception as e:
            raise FSTopException(f'error in \'perf stats\' query: {e}') from e
        if ret != 0:
            raise FSTopException(f'error in \'perf stats\' query: {out}')
        return json.loads(buf.decode('utf-8'))

    def mtype(self, typ):
        """Return the unit suffix for a metric type ('' when it has none)."""
        if typ == MetricType.METRIC_TYPE_PERCENTAGE:
            return "(%)"
        if typ == MetricType.METRIC_TYPE_LATENCY:
            return "(s)"
        return ''

    @staticmethod
    def _addstr_clipped(win, y, x, text, attr=None):
        """Write `text` into `win` at (y, x), clipped to the window bounds.

        Nothing is written when the position lies outside the window. This
        keeps narrow terminals and long client lists from raising
        curses.error, which an unclipped addstr() does when text does not fit.
        """
        max_y, max_x = win.getmaxyx()
        # keep the last column free: curses raises an error after writing
        # into the lower right corner of a window
        room = max_x - x - 1
        if y < 0 or x < 0 or y >= max_y or room <= 0:
            return
        if attr is None:
            win.addnstr(y, x, text, room)
        else:
            win.addnstr(y, x, text, room, attr)

    def refresh_top_line_and_build_coord(self):
        """Draw the column heading and return the column layout.

        Returns a dict mapping each column key (the bare item name, without
        unit suffix) to a tuple (x offset, heading width including padding).
        """
        # (key into the coordinate map, heading text) for every column, in order
        columns = [(item, item) for item in MAIN_WINDOW_TOP_LINE_ITEMS_START]
        columns += [(item, f'{item}{self.mtype(typ)}')
                    for item, typ in MAIN_WINDOW_TOP_LINE_METRICS.items()]
        columns += [(item, item) for item in MAIN_WINDOW_TOP_LINE_ITEMS_END]

        xp = 0
        x_coord_map = {}
        heading = []
        for key, label in columns:
            heading.append(label)
            nlen = len(label) + len(ITEMS_PAD)
            x_coord_map[key] = (xp, nlen)
            xp += nlen
        self._addstr_clipped(self.topl, 0, 0, ITEMS_PAD.join(heading),
                             curses.A_STANDOUT | curses.A_BOLD)
        return x_coord_map

    def refresh_client(self, client_id, metrics, counters, client_meta, x_coord_map, y_coord):
        """Draw the row of a single client at line `y_coord` of the main window.

        `counters` are the upper-cased counter names and `metrics` the values
        at the same positions. Counters that have no column are skipped;
        counters not listed in the client's 'valid_metrics' are shown as N/A.
        """
        for item in MAIN_WINDOW_TOP_LINE_ITEMS_END:
            coord = x_coord_map[item]
            if item == FS_TOP_MAIN_WINDOW_COL_MNTPT_HOST_ADDR:
                self._addstr_clipped(self.mainw, y_coord, coord[0],
                                     f'{client_meta[CLIENT_METADATA_MOUNT_POINT_KEY]}@'
                                     f'{client_meta[CLIENT_METADATA_HOSTNAME_KEY]}/'
                                     f'{client_meta[CLIENT_METADATA_IP_KEY]}')
        for item in MAIN_WINDOW_TOP_LINE_ITEMS_START:
            coord = x_coord_map[item]
            hlen = coord[1] - len(ITEMS_PAD)
            if item == FS_TOP_MAIN_WINDOW_COL_CLIENT_ID:
                # show the part after the first '.'; fall back to the whole
                # id when it contains no '.'
                parts = client_id.split('.')
                ident = parts[1] if len(parts) > 1 else client_id
                self._addstr_clipped(self.mainw, y_coord, coord[0], wrap(ident, hlen))
            elif item == FS_TOP_MAIN_WINDOW_COL_MNT_ROOT:
                self._addstr_clipped(self.mainw, y_coord, coord[0],
                                     wrap(client_meta[CLIENT_METADATA_MOUNT_ROOT_KEY], hlen))

        valid_metrics = client_meta.get('valid_metrics', ())
        for item, m in zip(counters, metrics):
            # look the type up by name so that type and column always agree
            typ = MAIN_WINDOW_TOP_LINE_METRICS.get(item)
            coord = x_coord_map.get(item)
            if typ is None or coord is None:
                continue  # counter without a column
            if item.lower() not in valid_metrics:
                text = "N/A"
            elif typ == MetricType.METRIC_TYPE_PERCENTAGE:
                text = f'{calc_perc(m)}'
            elif typ == MetricType.METRIC_TYPE_LATENCY:
                text = f'{calc_lat(m)}'
            else:
                continue
            self._addstr_clipped(self.mainw, y_coord, coord[0], text)

    def refresh_clients(self, x_coord_map, stats_json):
        """Draw one row per client, limited to the rows the main window has."""
        counters = [m.upper() for m in stats_json[GLOBAL_COUNTERS_KEY]]
        max_rows = self.mainw.getmaxyx()[0]
        client_metrics = stats_json[GLOBAL_METRICS_KEY].items()
        for y_coord, (client_id, metrics) in enumerate(client_metrics):
            if y_coord >= max_rows:
                break  # no rows left in the window
            self.refresh_client(client_id,
                                metrics,
                                counters,
                                stats_json[CLIENT_METADATA_KEY][client_id],
                                x_coord_map,
                                y_coord)

    def refresh_main_window(self, x_coord_map, stats_json):
        """Redraw the contents of the main window."""
        self.refresh_clients(x_coord_map, stats_json)

    def refresh_header(self, stats_json):
        """Draw the header lines.

        Returns False (after drawing a mismatch notice) when the stats version
        is not the supported one, True otherwise.
        """
        if stats_json['version'] != FS_TOP_SUPPORTED_VER:
            self._addstr_clipped(self.header, 0, 0, 'perf stats version mismatch!')
            return False
        client_metadata = stats_json[CLIENT_METADATA_KEY]
        num_clients = len(client_metadata)
        # clients reporting a mount point (anything but 'N/A')
        num_mounts = sum(1 for metadata in client_metadata.values()
                         if metadata[CLIENT_METADATA_MOUNT_POINT_KEY] != 'N/A')
        # clients whose metadata carries a kernel version
        num_kclients = sum(1 for metadata in client_metadata.values()
                           if "kernel_version" in metadata)
        num_libs = num_clients - (num_mounts + num_kclients)
        now = datetime.now().ctime()
        self._addstr_clipped(self.header, 0, 0,
                             FS_TOP_VERSION_HEADER_FMT.format(prog_name=FS_TOP_PROG_STR, now=now),
                             curses.A_STANDOUT | curses.A_BOLD)
        self._addstr_clipped(self.header, 1, 0,
                             FS_TOP_CLIENT_HEADER_FMT.format(num_clients=num_clients,
                                                             num_mounts=num_mounts,
                                                             num_kclients=num_kclients,
                                                             num_libs=num_libs))
        return True

    def display(self, _):
        """Display loop run under curses.wrapper(); the passed screen is unused.

        Draws the heading once, then queries and redraws header and main
        window once per second until `self.stop` is set.
        """
        x_coord_map = self.refresh_top_line_and_build_coord()
        self.topl.refresh()
        while not self.stop:
            stats_json = self.perf_stats_query()
            self.header.clear()
            self.mainw.clear()
            if self.refresh_header(stats_json):
                self.refresh_main_window(x_coord_map, stats_json)
            self.header.refresh()
            self.mainw.refresh()
            time.sleep(1)


if __name__ == '__main__':
    # command line: cluster name, ceph user, optional conf file, selftest switch
    parser = argparse.ArgumentParser(description='Ceph Filesystem top utility')
    parser.add_argument('--cluster', nargs='?', const='ceph', default='ceph',
                        help='Ceph cluster to connect (default: ceph)')
    parser.add_argument('--id', nargs='?', const='fstop', default='fstop',
                        help='Ceph user to use for connection (default: fstop)')
    parser.add_argument('--conffile', nargs='?', default=None,
                        help='Path to cluster configuration file')
    parser.add_argument('--selftest', dest='selftest', action='store_true',
                        help='run in selftest mode')
    args = parser.parse_args()
    err = False
    ft = FSTop(args)
    try:
        ft.init()
        if args.selftest:
            # one-shot query and version check, no curses UI
            ft.selftest()
            sys.stdout.write("selftest ok\n")
        else:
            ft.setup_curses()
    except FSTopException as fst:
        err = True
        sys.stderr.write(f'{fst.get_error_msg()}\n')
    except Exception as e:
        # top-level boundary: report anything unexpected and exit non-zero
        err = True
        sys.stderr.write(f'exception: {e}\n')
    finally:
        # always release the cluster connection
        ft.fini()
    sys.exit(0 if not err else -1)
