#!/usr/bin/env python3
import os
import sys
import subprocess
import signal
import glob
import argparse
import psutil

# Upper bound on CPU ids: used by procmon_stat_cstate while collecting
# per-core data and by report_procmon_cstate_in_csv to build the "CPU<n>"
# header labels of the per-core cstate CSV.
MAX_NCPUS = 4096

def get_log_name(args, mon):
    """Return the path of the raw log written for monitor *mon*.

    The file is located in ``args.outdir`` and is named
    ``<args.log>-procmon-<mon>__.log``.
    """
    file_name = "".join((args.log, "-procmon-", mon, "__.log"))
    return os.path.join(args.outdir, file_name)

def get_csv_name(args, mon, scope):
    """Return the path of the CSV report for monitor *mon* and *scope*.

    The file is located in ``args.outdir`` and is named
    ``<args.log>-procinsight-<mon>-<scope>.csv``.
    """
    file_name = "".join((args.log, "-procinsight-", mon, "-", scope, ".csv"))
    return os.path.join(args.outdir, file_name)

def transpose_2d_list(ll):
    """Swap rows and columns of *ll*, returning a list of tuples.

    The result is truncated to the shortest row, as ``zip`` does.
    """
    columns = zip(*ll)
    return [column for column in columns]

def report_stat_in_csv(col_names, stat_ll, f):
    """Write a table of ``(key, value)`` statistics to *f* as padded CSV.

    Args:
        col_names: Header labels. At most ``len(stat_ll[0]) + 1`` of them are
            used: one for the key column plus one per value column.
        stat_ll: Table rows. Every row is a sequence of ``(key, value)``
            pairs; the key of the first pair labels the row and the value of
            each pair fills one column. The first row determines the number
            of value columns. A value may be ``None`` (printed as ``N/A``)
            or a number.
        f: Writable text stream.

    Nothing is written when *stat_ll* is empty.
    """
    def get_sep(c, ncol):
        # The last column of a line is terminated by a newline.
        return "\n" if c == (ncol - 1) else ", "

    def has_fraction(n):
        # Whole-valued numbers are printed as integers. NaN and infinities
        # cannot be converted with int(), so floats are tested with
        # is_integer(), which simply reports False for them.
        if isinstance(n, float):
            return not n.is_integer()
        return int(n) != n

    if not stat_ll:
        return

    ncol = len(stat_ll[0])

    # print column header; the separator is based on the labels actually
    # available so the header line always ends with a newline
    header = col_names[0:ncol + 1]
    nhead = len(header)
    for c, col in enumerate(header):
        print("{0:^20}".format(col), end=get_sep(c, nhead), file=f)

    # print tuple by tuple
    for row in stat_ll:
        key, _ = row[0]
        print("{0:<20}".format(key), end=", ", file=f)
        for c in range(ncol):
            _, val = row[c]
            if val is None:
                cell = "N/A"
            elif has_fraction(val):
                cell = "%.4f" % val
            else:
                cell = int(val)
            print("{0:>20}".format(cell), end=get_sep(c, ncol), file=f)

def report_procmon_sched_in_csv(args):
    """Write the sched_wakeup statistics as CSV files.

    Two files are produced, named by ``get_csv_name``: one with the
    system-wide count (scope ``"sw"``) and one with a row per CPU (scope
    ``"core"``). Unless ``args.quiet`` is set, the system-wide table is also
    echoed to stdout.
    """
    # get stat
    sched_sw_stat, sched_core_stat = procmon_stat_sched(args)
    # report_stat_in_csv expects a list of rows; the system-wide stat is
    # a single row
    sched_sw_stat = [sched_sw_stat]
    col_names = ["scope", "system"]

    # generate csv for system-wide state
    with open(get_csv_name(args, "sched", "sw"), "w") as f:
        report_stat_in_csv(col_names, sched_sw_stat, f)

    # echo after the CSV file has been closed
    if not args.quiet:
        print("## Sched_wakeup count", file=sys.stdout)
        report_stat_in_csv(col_names, sched_sw_stat, sys.stdout)
        print("\n")

    # generate csv for per-core state
    with open(get_csv_name(args, "sched", "core"), "w") as f:
        report_stat_in_csv(col_names, sched_core_stat, f)

def report_procmon_cstate_in_csv(args):
    """Write the cstate statistics as CSV files.

    Two files are produced, named by ``get_csv_name``: one with the
    system-wide values (scope ``"sw"``) and one with a column per CPU (scope
    ``"core"``). Unless ``args.quiet`` is set, the system-wide table is also
    echoed to stdout.
    """
    # get stat
    cstate_sw_stat, cstate_core_stat = procmon_stat_cstate(args)
    # Transpose so that every metric becomes a row and every CPU (or the
    # whole system) becomes a column.
    cstate_sw_stat = transpose_2d_list([cstate_sw_stat])
    cstate_core_stat = transpose_2d_list(cstate_core_stat)

    # generate csv for system-wide state
    sw_col_names = ["info", "system"]
    with open(get_csv_name(args, "cstate", "sw"), "w") as f:
        report_stat_in_csv(sw_col_names, cstate_sw_stat, f)

    # echo after the CSV file has been closed
    if not args.quiet:
        print("## Cstate states", file=sys.stdout)
        report_stat_in_csv(sw_col_names, cstate_sw_stat, sys.stdout)
        print("\n")

    # generate csv for per-core state; report_stat_in_csv trims the header
    # list to the number of columns actually present
    core_col_names = ["info"] + ["CPU" + str(c) for c in range(MAX_NCPUS)]
    with open(get_csv_name(args, "cstate", "core"), "w") as f:
        report_stat_in_csv(core_col_names, cstate_core_stat, f)

def report_procmon_energy_in_csv(args):
    """Write the system-wide energy statistics as a CSV file.

    The file is named by ``get_csv_name`` (monitor ``"energy"``, scope
    ``"sw"``). Unless ``args.quiet`` is set, the table is also echoed to
    stdout.
    """
    # get stat; transposing the single stat list turns every metric into
    # its own one-column row
    energy_sw_stat = transpose_2d_list([procmon_stat_energy(args)])

    # generate csv for system-wide state
    col_names = ["energy", "system"]
    with open(get_csv_name(args, "energy", "sw"), "w") as f:
        report_stat_in_csv(col_names, energy_sw_stat, f)

    # echo after the CSV file has been closed
    if not args.quiet:
        print("## Energy consumption", file=sys.stdout)
        report_stat_in_csv(col_names, energy_sw_stat, sys.stdout)
        print("\n")

def report_procmon_perf_in_csv(args):
    """Write the system-wide performance-counter statistics as a CSV file.

    The file is named by ``get_csv_name`` (monitor ``"perf"``, scope
    ``"sw"``). Unless ``args.quiet`` is set, the table is also echoed to
    stdout.
    """
    # get stat; transposing the single stat list turns every metric into
    # its own one-column row
    perf_sw_stat = transpose_2d_list([procmon_stat_perf(args)])

    # generate csv for system-wide state
    col_names = ["info", "system"]
    with open(get_csv_name(args, "perf", "sw"), "w") as f:
        report_stat_in_csv(col_names, perf_sw_stat, f)

    # echo after the CSV file has been closed
    if not args.quiet:
        print("## Performance counters", file=sys.stdout)
        report_stat_in_csv(col_names, perf_sw_stat, sys.stdout)
        print("\n")

def report_procmons_in_csv(args):
    """Generate every CSV report in turn: sched, cstate, energy, perf."""
    reporters = (
        report_procmon_sched_in_csv,    # -s, --sched
        report_procmon_cstate_in_csv,   # -c, --cstate
        report_procmon_energy_in_csv,   # -e, --energy
        report_procmon_perf_in_csv,     # -p, --perf
    )
    for reporter in reporters:
        reporter(args)

def str_to_nstr(s):
    """Return *s* with every ',' and '%' character removed."""
    dropped = (',', '%')
    return "".join(ch for ch in s if ch not in dropped)

def procmon_stat_sched(args):
    """Count sched wakeups per CPU from the "sched" monitor log.

    The log is expected to contain a line whose last word is ``cpus=<n>``,
    followed by lines whose last word is ``target_cpu=<id>``; every such
    line counts as one wakeup on that CPU. Other lines, including blank
    ones, are ignored.

    Returns:
        A pair ``(sw_stat, core_stat)``: ``sw_stat`` is
        ``[("sched_wakeup", total)]`` and ``core_stat`` holds one
        single-element list ``[("CPU<i>", count)]`` per CPU.

    Raises:
        ValueError: If the log has no ``cpus=<n>`` line.
    """
    log = get_log_name(args, "sched")

    def get_kv(line):
        # Only the last whitespace-separated word of a line is inspected;
        # anything that is not exactly "key=value" yields an empty pair.
        words = line.split()
        if not words:
            return ("", "")
        toks = words[-1].split('=')
        if len(toks) != 2:
            return ("", "")
        return (toks[0], toks[1])

    with open(log, 'r') as f:
        # find the "cpus=16" line, then initialize the per-core stat
        ncpus = None
        for line in f:
            k, v = get_kv(line)
            if k == "cpus":
                ncpus = int(v)
                break
        if ncpus is None:
            raise ValueError("no 'cpus=<n>' line found in " + log)
        core_stat = [0] * ncpus

        # read the remaining lines, "... target_cpu=006"
        for line in f:
            k, v = get_kv(line)
            if k == "target_cpu":
                core_stat[int(v)] += 1

    # convert to lists of (name, value) tuples
    sw_stat2 = [("sched_wakeup", sum(core_stat))]
    core_stat2 = [[("CPU" + str(cpu), count)]
                  for cpu, count in enumerate(core_stat)]
    return sw_stat2, core_stat2
    
def procmon_stat_cstate(args):
    """Collect per-CPU and system-wide statistics from the "cstate" log.

    The log is expected to contain a header line whose first token is
    ``CPU`` (columns may be separated by ``|`` and/or whitespace), followed
    by one data line per CPU. From every data line the columns ``C0``,
    ``POLL``, ``C1``, ``C2``, ``C3`` and ``Freq`` are read as floats. Blank
    lines are ignored; if a CPU appears more than once the last line wins.

    Returns:
        A pair ``(sw_stat, core_stat)``. ``core_stat[i]`` is the list of
        ``(column, value)`` pairs of CPU ``i`` for ids ``0..max id``; a CPU
        id that is absent from the log gets ``(column, None)`` pairs.
        ``sw_stat`` is the list of ``(column, mean)`` pairs, averaged over
        the CPUs present in the log.

    Raises:
        ValueError: If the log has no per-CPU data or a CPU id is outside
            ``0..MAX_NCPUS-1``.
    """
    col_order = ('C0', 'POLL', 'C1', 'C2', 'C3', 'Freq')
    log = get_log_name(args, "cstate")

    def tokenize(line):
        return line.replace('|', ' ').split()

    reported = {}  # cpu id -> [(column, value), ...]
    with open(log, 'r') as f:
        # get column names
        col_names = []
        for line in f:
            tokens = tokenize(line)
            if tokens and tokens[0] == "CPU":
                col_names = tokens
                break

        # parse per-core stat
        for line in f:
            tokens = tokenize(line)
            if not tokens:
                continue
            stat_dict = dict(zip(col_names, tokens))
            stat_list = [(col, float(stat_dict[col])) for col in col_order]
            cpu_id = int(stat_dict['CPU'])
            if not 0 <= cpu_id < MAX_NCPUS:
                raise ValueError("CPU id %d out of range in %s" % (cpu_id, log))
            reported[cpu_id] = stat_list

    if not reported:
        raise ValueError("no per-CPU cstate data found in " + log)

    # The table spans ids 0..max id, whatever order the lines came in;
    # every absent id gets its own placeholder list (no shared references).
    ncpu = max(reported) + 1
    core_stat = [reported[cpu] if cpu in reported
                 else [(col, None) for col in col_order]
                 for cpu in range(ncpu)]

    # aggregate per-core state to system-wide stat (mean over the CPUs
    # that actually reported values), summing in CPU id order
    cpu_ids = sorted(reported)
    sw_stat = []
    for pos, col in enumerate(col_order):
        total = sum(reported[cpu][pos][1] for cpu in cpu_ids)
        sw_stat.append((col, total / float(len(cpu_ids))))

    return sw_stat, core_stat

def parse_log_ef(log, parse_tbl):
    """Extract numeric fields from a text log according to *parse_tbl*.

    Args:
        log: Path of the log file to read.
        parse_tbl: Iterable of ``(key, (pos, name))`` entries. For every
            line that contains *key* as a substring, the whitespace-separated
            token at index *pos* is converted to a float (after removing
            ``,`` and ``%`` via ``str_to_nstr``) and stored under *name*.
            When several lines match, the last parsable one wins.

    Returns:
        A list of ``(name, value)`` pairs in *parse_tbl* order. ``value`` is
        ``None`` for names that were never found or never parsable.
    """
    # parse the log
    stat = {}
    with open(log, 'r') as f:
        for line in f:
            tokens = line.split()
            for key, (pos, name) in parse_tbl:
                if key not in line:
                    continue
                try:
                    stat[name] = float(str_to_nstr(tokens[pos]))
                except (IndexError, ValueError):
                    # The line mentions the key but has no numeric token at
                    # the expected position; keep whatever was recorded so
                    # far so the field ends up as None instead of aborting
                    # the whole report.
                    continue

    # keep the table order for easier interpretation
    return [(name, stat.get(name)) for _, (_, name) in parse_tbl]

def procmon_stat_energy(args):
    """Return the energy statistics parsed from the "energy" monitor log.

    The Joules value is taken from the third token of lines containing
    ``"Joules"`` and the duration from the first token of lines containing
    ``"seconds"``.

    Returns:
        ``[("J", joules), ("J/sec", rate)]``. ``rate`` is ``None`` when
        either value is missing from the log or the duration is zero;
        ``joules`` is ``None`` when it is missing.
    """
    parse_tbl = ( ("Joules", (2, "J")),
                  ("seconds", (0, "__seconds")) )

    log = get_log_name(args, "energy")
    stat = parse_log_ef(log, parse_tbl)

    # replace the raw duration entry with the derived rate (J/sec)
    joules = stat[0][1]
    seconds = stat[1][1]
    if joules is None or not seconds:
        # missing data or a zero duration: no meaningful rate
        rate = None
    else:
        rate = joules / float(seconds)
    stat[1] = ("J/sec", rate)
    return stat

def procmon_stat_perf(args):
    """Return the performance-counter statistics parsed from the "perf" log.

    The result is the list of ``(name, value)`` pairs produced by
    ``parse_log_ef`` for the table below.
    """
    # (substring identifying the line, token index, reported name)
    fields = (
        ("seconds time elapsed", 0, "time (sec)"),
        ("cycles  ", 0, "cycles"),
        ("instructions", 0, "instructions"),
        ("instructions", 3, "ipc"),
        ("stalled-cycles-frontend", 3, "frontend-stall (%)"),
        ("stalled-cycles-backend", 3, "backend-stall (%)"),
        # NOTE(review): keys are matched as substrings and a later matching
        # line overwrites an earlier one, so "branches" would also match a
        # branch-misses line if that line mentions "branches" elsewhere --
        # confirm against a real log.
        ("branches", 0, "branches"),
        ("branch-misses", 3, "branch-misses (%)"),
        ("page-faults", 0, "page-faults"),
        ("context-switches", 0, "context-switches"),
        ("cpu-migrations", 0, "cpu-migrations"),
    )
    parse_tbl = tuple((key, (pos, name)) for key, pos, name in fields)

    return parse_log_ef(get_log_name(args, "perf"), parse_tbl)

def get_cmd_options(argv):
    """Parse the command-line arguments in *argv*.

    Returns the ``argparse`` namespace with ``outdir`` and ``log`` (both
    required) and the boolean ``quiet`` flag.
    """
    cli = argparse.ArgumentParser(
        prog="procinsight",
        description="Report CPU statistics and system-wide scheduling statistics",
    )

    # required options that take a value
    required_opts = (
        ('-o', '--outdir', 'output directory'),
        ('-l', '--log', 'log file prefix'),
    )
    for short_flag, long_flag, help_text in required_opts:
        cli.add_argument(short_flag, long_flag, action='store',
                         required=True, help=help_text)

    cli.add_argument('-q', '--quiet', action='store_true',
                     help='do not print result to stdout')
    return cli.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: parse the command line (without the program name)
    # and generate every report.
    args = get_cmd_options(sys.argv[1:])
    report_procmons_in_csv(args)



