#! /usr/bin/env ocaml

(* ASSUMPTIONS:
   - the 1-st command line argument (working directory):
     - designates an existing readable directory
     - which contains *.ncoms, *.time and *.perf files produced by bench.sh script
   - the 2-nd command line argument (number of iterations):
     - is a positive integer
   - the 3-rd command line argument (NEW Coq version):
     - is a string naming the NEW Coq version; it is only printed in the legend
   - the 4-th command line argument (OLD Coq version):
     - is a string naming the OLD Coq version; it is only printed in the legend
   - the 5-th command line argument (minimal user time):
     - is a positive floating point number
   - the 6-th command line argument determines the name of the column according to which the resulting table will be sorted.
     Valid values are:
     - package_name
     - user_time_pdiff
   - the rest of the command line-arguments
     - are names of benchmarked Coq OPAM packages for which bench.sh script generated *.time and *.perf files
 *)

#use "topfind";;
#require "unix";;
#print_depth 100000000;;
#print_length 100000000;;

open Printf
open Unix
;;

(* Turn on the recording of exception backtraces for the whole script. *)
let () = Printexc.record_backtrace true
;;

(* Measurements attached to one package.  The record is polymorphic so that it
   can hold raw measurements as well as values derived from them: ['a] is the
   type of the user time, ['b] the type shared by the four counters. *)
type ('a,'b) pkg_timings = {
  user_time  : 'a;
  num_instr  : 'b;
  num_cycles : 'b;
  num_mem    : 'b;
  num_faults : 'b;
}
;;

(* [reduce_pkg_timings m_f m_a t] collapses the list of records [t] into a
   single record, one field at a time: the [user_time] values of all the
   elements of [t] are collected in a list (in the order of [t]) and handed to
   [m_f]; the values of each counter field are collected likewise and handed
   to [m_a]. *)
let reduce_pkg_timings (m_f : 'a list -> 'c) (m_a : 'b list -> 'd) (t : ('a,'b) pkg_timings list) : ('c,'d) pkg_timings =
  (* All the values of one field, selected by [project]. *)
  let column project = List.map project t in
  { user_time  = m_f (column (fun m -> m.user_time))
  ; num_instr  = m_a (column (fun m -> m.num_instr))
  ; num_cycles = m_a (column (fun m -> m.num_cycles))
  ; num_mem    = m_a (column (fun m -> m.num_mem))
  ; num_faults = m_a (column (fun m -> m.num_faults))
  }
;;

(******************************************************************************)
(* BEGIN Copied from batteries, to remove *)
(******************************************************************************)
(* [run_and_read cmd] runs [cmd] through [Unix.system] with its standard output
   redirected to a fresh temporary file, and returns [(status, output)] where
   [status] is the value returned by [Unix.system] and [output] is the whole
   content of that file.

   The temporary file is removed before returning, also when running the
   command or reading the file raises; in that case the exception is
   re-raised after the cleanup. *)
let run_and_read cmd =
  (* Whole content of the file [fn]; the channel is closed on every path. *)
  let string_of_file fn =
    let buff_size = 1024 in
    let buff = Buffer.create buff_size in
    let chunk = Bytes.create buff_size in
    let ic = open_in fn in
    let rec slurp () =
      let was_read = input ic chunk 0 buff_size in
      (* [input] returns 0 only at end of file *)
      if was_read <> 0 then begin
        Buffer.add_subbytes buff chunk 0 was_read;
        slurp ()
      end
    in
    match slurp () with
    | () -> close_in ic; Buffer.contents buff
    | exception e -> close_in_noerr ic; raise e
  in
  let tmp_fn = Filename.temp_file "" "" in
  let result =
    try
      (* The file name is quoted so that a temporary directory whose name
         contains blanks or shell metacharacters does not break the command. *)
      let status = Unix.system (cmd ^ " > " ^ Filename.quote tmp_fn) in
      (status, string_of_file tmp_fn)
    with e ->
      (* Best-effort cleanup: the original failure is the one worth reporting. *)
      (try Unix.unlink tmp_fn with Unix.Unix_error _ -> ());
      raise e
  in
  Unix.unlink tmp_fn;
  result
;;

(* Left-to-right function composition: [(f %> g) x] is [g (f x)]. *)
let ( %> ) f g = fun x -> f x |> g

(* [run cmd] executes [cmd] with [run_and_read] and returns only the captured
   standard output; the exit status is dropped. *)
let run cmd =
  let (_status, output) = run_and_read cmd in
  output

(* Minimal [Float] module: it only provides [nan].
   The right-hand side is the initially opened [nan]; it is deliberately not
   written [Pervasives.nan], because the [Pervasives] module is deprecated and
   no longer exists in OCaml 5. *)
module Float = struct
  let nan = nan
end

(* The standard [List] module extended with a few helpers. *)
module List = struct
  include List

  (* Tail-recursive worker of [init]: conses [f i], [f (i+1)], ..., [f (n-1)]
     onto [acc], so the result comes out in reverse order. *)
  let rec init_tailrec_aux acc i n f =
    if i >= n then acc
    else init_tailrec_aux (f i :: acc) (i+1) n f

  (* Non tail-recursive worker of [init]: [[f i; f (i+1); ...; f (n-1)]]. *)
  let rec init_aux i n f =
    if i >= n then []
    else
      let r = f i in
      r :: init_aux (i+1) n f

  (* Length above which [init] switches to the tail-recursive worker. *)
  let rev_init_threshold =
    match Sys.backend_type with
    | Sys.Native | Sys.Bytecode -> 10_000
    (* We don't know the size of the stack, better be safe and assume it's small. *)
    | Sys.Other _ -> 50

  (* [init len f] is [[f 0; f 1; ...; f (len-1)]].
     Raises [Invalid_argument "List.init"] if [len] is negative. *)
  let init len f =
    if len < 0 then invalid_arg "List.init" else
    if len > rev_init_threshold then rev (init_tailrec_aux [] 0 len f)
    else init_aux 0 len f

  (* [drop n l] is [l] without its first [n] elements; it is the empty list
     when [l] has fewer than [n] elements, and [l] itself when [n <= 0]. *)
  let rec drop n = function
    | _ :: l when n > 0 -> drop (n-1) l
    | l -> l

  (* [reduce f [x1; x2; ...; xn]] is [f (... (f (f x1 x2) x3) ...) xn].
     Raises [Invalid_argument] on the empty list. *)
  let reduce f = function
    | [] ->
      invalid_arg "List.reduce: Empty List"
    | h :: t ->
      fold_left f h t

  (* Smallest / largest element of a non-empty list, by polymorphic comparison.
     On the right-hand sides, [min] and [max] are still the initially opened
     two-argument functions (a plain [let] is not recursive); they are not
     written [Pervasives.min] / [Pervasives.max] because the [Pervasives]
     module is deprecated and no longer exists in OCaml 5. *)
  let min l = reduce min l
  let max l = reduce max l

end
;;

(* The standard [String] module extended with [rchop]. *)
module String = struct

  include String

  (* [rchop ~n s] is [s] without its last [n] characters ([n] defaults to 1);
     the result is the empty string when [s] has at most [n] characters.
     Raises [Invalid_argument] if [n] is negative. *)
  let rchop ?(n = 1) s =
    if n < 0 then
      invalid_arg "String.rchop: number of characters to chop is negative";
    let kept = length s - n in
    if kept > 0 then sub s 0 kept else ""

end
;;

(* Rendering of a table framed with box-drawing characters.

   A table has one header per column.  A row gives, for every column, a list
   of cells; cells of the same column are right-aligned with each other and
   the resulting group is centered in its column.  All the rows, including
   the [top] row printed right below the headers, must have the same shape:
   as many columns as there are headers and, column by column, the same
   number of cells.

   Widths are computed with [String.length], i.e. in bytes, so cells or
   headers containing multi-byte characters would not line up. *)
module Table :
sig
  type header = string
  type row = string list list
  val print : header list -> row -> row list -> string
end =
struct
  type header = string

  type row = string list list

  (* Number of blanks between two cells of the same column *)
  let val_padding = 2
  (* Number of blanks (dashes in separator lines) between the content of a
     column and the vertical bars around it *)
  let row_padding = 1

  (* Raises [Failure "Heterogeneous data"] unless [b] holds. *)
  let homogeneous b = if b then () else failwith "Heterogeneous data"

  (* Splits a list of non-empty lists into the list of their heads and the
     list of their tails; raises [Failure "vert_split"] on an empty element. *)
  let vert_split (ls : 'a list list) =
    let split = function
      | [] -> failwith "vert_split"
      | x :: l -> (x, l)
    in
    List.split (List.map split ls)

  (* Centers [s] in a field of [n] characters; when the padding cannot be
     split evenly the extra blank goes to the right. *)
  let justify n s =
    let len = String.length s in
    let () = assert (len <= n) in
    let lft = (n - len) / 2 in
    let rgt = n - lft - len in
    String.make lft ' ' ^ s ^ String.make rgt ' '

  (* Right-aligns every cell of [data] in the width given at the same position
     of [layout], then joins the cells with [val_padding] blanks. *)
  let justify_row layout data =
    let right_align n s =
      let len = String.length s in
      let () = assert (len <= n) in
      String.make (n - len) ' ' ^ s
    in
    String.concat (String.make val_padding ' ') (List.map2 right_align layout data)

  (* Box-drawing character found at the given horizontal / vertical position
     of the frame. *)
  let angle hkind vkind = match hkind, vkind with
  | `Lft, `Top -> "┌"
  | `Rgt, `Top -> "┐"
  | `Mid, `Top -> "┬"
  | `Lft, `Mid -> "├"
  | `Rgt, `Mid -> "┤"
  | `Mid, `Mid -> "┼"
  | `Lft, `Bot -> "└"
  | `Rgt, `Bot -> "┘"
  | `Mid, `Bot -> "┴"

  (* Horizontal line of the frame ([vkind] tells whether it is the top, a
     middle or the bottom one) for columns of widths [col_size]. *)
  let print_separator vkind col_size =
    (* [dashes n] is the horizontal bar repeated [n] times; built in a buffer
       rather than by repeated concatenation. *)
    let dashes n =
      let buf = Buffer.create 16 in
      for _i = 1 to n do Buffer.add_string buf "─" done;
      Buffer.contents buf
    in
    let () = assert (0 < List.length col_size) in
    let pad = dashes row_padding in
    angle `Lft vkind ^ pad ^
    String.concat (pad ^ angle `Mid vkind ^ pad) (List.map dashes col_size) ^
    pad ^ angle `Rgt vkind

  (* Line of the table whose columns are all blank. *)
  let print_blank col_size =
    let () = assert (0 < List.length col_size) in
    let pad = String.make row_padding ' ' in
    let blank n = String.make n ' ' in
    "│" ^ pad ^ String.concat (pad ^ "│" ^ pad) (List.map blank col_size) ^ pad ^ "│"

  (* Line of the table showing [row], whose elements are expected to be
     already padded to the width of their column. *)
  let print_row row =
    let () = assert (0 < List.length row) in
    let pad = String.make row_padding ' ' in
    "│" ^ pad ^ String.concat (pad ^ "│" ^ pad) row ^ pad ^ "│"

  (* Invariant : all rows must have the same shape *)

  (* [print headers top rows] is the text of the table: top border, headers,
     a blank line, the [top] row, a separator, the [rows] and the bottom
     border, joined with newlines (no trailing newline).
     Raises [Failure "Heterogeneous data"] if [top] or one of the [rows] does
     not have the expected shape. *)
  let print (headers : header list) (top : row) (rows : row list) =
    (* Sanitize input.  The [top] row is checked like any other row: the
       layout computation below relies on it having the same shape. *)
    let ncolums = List.length headers in
    let shape_of (row : row) : int list =
      let () = homogeneous (List.length row = ncolums) in
      List.map List.length row
    in
    let expected_shape = shape_of top in
    let () =
      List.iter (fun row -> homogeneous (shape_of row = expected_shape)) rows
    in
    (* Compute layout: for each of the [n] leading columns, the width of each
       of its cells, i.e. the length of the longest string found at that
       position in any row. *)
    let rec layout n (rows : row list) =
      if n = 0 then []
      else
        let (col, rows) = vert_split rows in
        let ans = layout (n - 1) rows in
        let widths = match col with
          | [] -> []
          | cells :: others ->
            List.fold_left
              (fun acc cells -> List.map2 max acc (List.map String.length cells))
              (List.map String.length cells) others
        in
        widths :: ans
    in
    let layout = layout ncolums (top::rows) in
    (* Width of a column: enough for its header and for its cells separated by
       [val_padding] blanks. *)
    let column_width hd shape =
      let data_size = match shape with
      | [] -> 0
      | n :: shape -> List.fold_left (fun accu n -> accu + n + val_padding) n shape
      in
      max (String.length hd) data_size
    in
    let col_size = List.map2 column_width headers layout in
    (* Justify the data *)
    let render_row row =
      List.map2 justify col_size (List.map2 justify_row layout row)
    in
    let headers = List.map2 justify col_size headers in
    let top = render_row top in
    let rows = List.map render_row rows in
    (* Print the table *)
    let lines =
      [ print_separator `Top col_size
      ; print_row headers
      ; print_blank col_size
      ; print_row top
      ; print_separator `Mid col_size ]
      @ List.map print_row rows
      @ [ print_separator `Bot col_size ]
    in
    String.concat "\n" lines
end

(******************************************************************************)
(* END Copied from batteries, to remove *)
(******************************************************************************)

(* [add_timings a b] is the field-wise sum of the two measurement records:
   float addition for the user time, integer addition for the counters. *)
let add_timings a b =
  (* Sum of one integer field, selected by [field], over both records. *)
  let total field = field a + field b in
  { user_time  = a.user_time +. b.user_time
  ; num_instr  = total (fun t -> t.num_instr)
  ; num_cycles = total (fun t -> t.num_cycles)
  ; num_mem    = total (fun t -> t.num_mem)
  ; num_faults = total (fun t -> t.num_faults)
  }

(* [mk_pkg_timings work_dir pkg_name suffix iteration] reads the measurements
   recorded for one iteration of one package and returns their total.

   With [base] standing for [work_dir ^ "/" ^ pkg_name ^ suffix ^ iteration]:
   - [base.ncoms] holds the number N of measured commands;
   - for i = 1..N, [base.i.time] holds space-separated fields, of which the
     first three are read as user time (float), memory (int) and faults (int);
   - for i = 1..N, [base.i.perf] is searched for the "instructions:u" and
     "cycles:u" lines; a counter that cannot be read counts as 0.

   The N records are added up with [add_timings].
   NOTE(review): this also adds up [num_mem], which the final table labels
   "max resident mem" -- confirm that a sum is what is wanted there.

   Raises [Failure] when a file cannot be parsed or when N is zero, and
   [Invalid_argument] when N is negative. *)
let mk_pkg_timings work_dir pkg_name suffix iteration =
  let base = work_dir ^ "/" ^ pkg_name ^ suffix ^ string_of_int iteration in
  (* Shell command printing [file].  The path is quoted so that blanks or
     shell metacharacters in [work_dir] or [pkg_name] are taken literally. *)
  let cat file = "cat " ^ Filename.quote file in
  (* Standard output of [command] minus its last character (the newline). *)
  let output_of command = command |> run |> String.rchop ~n:1 in
  let ncoms = cat (base ^ ".ncoms") |> output_of |> int_of_string in
  let timings = List.init ncoms (fun ncom ->
      let base = base ^ "." ^ string_of_int (ncom+1) in
      let time_fields =
        cat (base ^ ".time") |> output_of |> String.split_on_char ' ' in
      (* Perf can indeed be not supported in some systems, so we must fail gracefully *)
      let perf_counter event =
        try
          cat (base ^ ".perf") ^ " | grep " ^ event ^ " | awk '{print $1}' | sed 's/,//g'"
          |> output_of |> int_of_string
        with Failure _ -> 0
      in
      { user_time = List.nth time_fields 0 |> float_of_string
      ; num_instr = perf_counter "instructions:u"
      ; num_cycles = perf_counter "cycles:u"
      ; num_mem = List.nth time_fields 1 |> int_of_string
      ; num_faults = List.nth time_fields 2 |> int_of_string
      })
  in
  match timings with
  | [] ->
    (* This is a problem with the input data, not an internal invariant. *)
    failwith (Printf.sprintf "mk_pkg_timings: %s.ncoms records no command" base)
  | timing :: rest -> List.fold_left add_timings timing rest
;;

(* Process the command line parameters.
   Six positional arguments are read below, so [Sys.argv] (which also holds
   the program name at index 0) must have at least seven elements; whatever
   follows them is the list of packages. *)
assert (Array.length Sys.argv > 6);
let work_dir = Sys.argv.(1) in
let num_of_iterations = int_of_string Sys.argv.(2) in
let new_coq_version = Sys.argv.(3) in
let old_coq_version = Sys.argv.(4) in
let minimal_user_time = float_of_string Sys.argv.(5) in
let sorting_column = Sys.argv.(6) in
let coq_opam_packages = Sys.argv |> Array.to_list |> List.drop 7 in

(* ASSUMPTIONS:

   "working_dir" contains all the files produced by the following command:

      two_points_on_the_same_branch.sh $working_directory $coq_repository $coq_branch[:$new:$old] $num_of_iterations coq_opam_package_1 coq_opam_package_2 ... coq_opam_package_N
-sf
*)

(* Run a given bash command;
   wait until it terminates;
   check if its exit status is 0;
   return its whole stdout as a string. *)

(* [proportional_difference_of_integers new_value old_value] is the change
   from [old_value] to [new_value], expressed as a percentage of [old_value];
   it is [Float.nan] when [old_value] is zero. *)
let proportional_difference_of_integers new_value old_value =
  match old_value with
  | 0 -> Float.nan
  | _ ->
    let delta = float_of_int (new_value - old_value) in
    delta /. float_of_int old_value *. 100.0
in

(* Choose the ordering of the result table first, so that an unsupported
   "sorting_column" is reported right away, with an explanatory message,
   instead of after all the measurement files have been processed. *)
let compare_rows =
  match sorting_column with
  | "user_time_pdiff" ->
    fun (_,_,_,perf1) (_,_,_,perf2) ->
      compare perf1.user_time perf2.user_time
  | "package_name" ->
    fun (n1,_,_,_) (n2,_,_,_) -> compare n1 n2
  | other ->
    invalid_arg
      (Printf.sprintf
         "unknown sorting column %S (valid values: package_name, user_time_pdiff)"
         other)
in

(* parse the *.time and *.perf files:
   for every package, one record per iteration for the NEW and for the OLD build *)
coq_opam_packages
|> List.map
     (fun package_name ->
       let measurements suffix =
         List.init num_of_iterations succ
         |> List.map (mk_pkg_timings work_dir package_name suffix)
       in
       package_name, measurements ".NEW.", measurements ".OLD.")

(* from the list of measured values, select just the minimal ones *)

|> List.map
  (fun ((package_name : string),
        (new_measurements : (float, int) pkg_timings list),
        (old_measurements : (float, int) pkg_timings list)) ->
    let f_min : float list -> float = List.min in
    let i_min : int list -> int = List.min in
    package_name,
    reduce_pkg_timings f_min i_min new_measurements,
    reduce_pkg_timings f_min i_min old_measurements
  )

(* compute the "proportional differences in % of the NEW measurement and the OLD measurement" of all measured values *)
|> List.map
     (fun (package_name, new_t, old_t) ->
       package_name, new_t, old_t,
       { user_time  = (new_t.user_time -. old_t.user_time) /. old_t.user_time *. 100.0
       ; num_instr  = proportional_difference_of_integers new_t.num_instr  old_t.num_instr
       ; num_cycles = proportional_difference_of_integers new_t.num_cycles old_t.num_cycles
       ; num_mem    = proportional_difference_of_integers new_t.num_mem    old_t.num_mem
       ; num_faults = proportional_difference_of_integers new_t.num_faults old_t.num_faults
       })

(* sort the table with results *)
|> List.sort compare_rows

(* Keep only measurements that took at least "minimal_user_time" (in seconds). *)

|> List.filter
     (fun (_, new_t, old_t, _) ->
        minimal_user_time <= new_t.user_time && minimal_user_time <= old_t.user_time)

(* Below we take the measurements and format them to stdout.
   Each package becomes one table row: its name followed, for every measured
   quantity, by the NEW value, the OLD value and their proportional difference. *)

|> List.map begin fun (package_name, new_t, old_t, perc) ->

  let precision = 2 in
  let prf f = Printf.sprintf "%.*f" precision f in
  let pri n = Printf.sprintf "%d" n in

  [
    [ package_name ];
    [ prf new_t.user_time; prf old_t.user_time; prf perc.user_time ];
    [ pri new_t.num_cycles; pri old_t.num_cycles; prf perc.num_cycles ];
    [ pri new_t.num_instr; pri old_t.num_instr; prf perc.num_instr ];
    [ pri new_t.num_mem; pri old_t.num_mem; prf perc.num_mem ];
    [ pri new_t.num_faults; pri old_t.num_faults; prf perc.num_faults ];
  ]

  end

|> fun measurements ->

    (* The headers are listed in the same order as the cell groups built above. *)
    let headers = [
      "";
      "user time [s]";
      "CPU cycles";
      "CPU instructions";
      "max resident mem [KB]";
      "mem faults";
    ] in

    let descr = ["NEW"; "OLD"; "PDIFF"] in
    let top = [ [ "package_name" ]; descr; descr; descr; descr; descr ] in

    printf "%s%!" (Table.print headers top measurements)
;

(* Legend explaining the abbreviations and the columns of the table printed
   above, together with the NEW and OLD version labels received on the
   command line.
   ejgallego: disable this as it is very verbose and brings up little info in the log.
   The [if false] guard below keeps the text in the source while never
   printing it. *)
if false then begin
printf "

PDIFF = proportional difference between measurements done for the NEW and the OLD Coq version
      = (NEW_measurement - OLD_measurement) / OLD_measurement * 100%%

NEW = %s
OLD = %s

Columns:

  1. user time [s]

     Total number of CPU-seconds that the process used directly (in user mode), in seconds.
     (In other words, \"%%U\" quantity provided by the \"/usr/bin/time\" command.)

  2. CPU cycles

     Total number of CPU-cycles that the process used directly (in user mode).
     (In other words, \"cycles:u\" quantity provided by the \"/usr/bin/perf\" command.)

  3. CPU instructions

     Total number of CPU-instructions that the process used directly (in user mode).
     (In other words, \"instructions:u\" quantity provided by the \"/usr/bin/perf\" command.)

  4. max resident mem [KB]

     Maximum resident set size of the process during its lifetime, in Kilobytes.
     (In other words, \"%%M\" quantity provided by the \"/usr/bin/time\" command.)

  5. mem faults

     Number of major, or I/O-requiring, page faults that occurred while the process was running.
     These are faults where the page has actually migrated out of primary memory.
     (In other words, \"%%F\" quantity provided by the \"/usr/bin/time\" command.)

" new_coq_version old_coq_version;
end
