#!/usr/bin/perl -w

BEGIN { require 'constants.pl'; }
#use constant ROBINSON_S_CONSTANT => 0.30;
#use constant N_SIGNIFICANT_TOKENS => 150;
#use constant ROBINSON_MIN_PROB_STRENGTH => 0.2;
#use constant PROB_BOUND_LOWER => 0.001;
#use constant PROB_BOUND_UPPER => 0.999;
#use constant USE_CHI_COMBINING => 0;
#use constant ROBINSON_X => 0.45;

# usage: draw-bayes-histogram [--spam=spam.log] [--nonspam=nonspam.log]
#               [--nocollapse] [--nozoom] [--buckets=N]
#
# Defaults: --spam=spam.log, --nonspam=nonspam.log, --buckets=25.
#
# or: draw-bayes-histogram spam.log nonspam.log   (backwards compatible)

use Getopt::Long;
use vars qw($opt_spam $opt_nonspam $opt_nocollapse $opt_nozoom
        $opt_buckets);

# Parse the command line.  With no explicit linkage, Getopt::Long stores each
# option in the matching $opt_* package variable.  Bail out on a malformed
# command line instead of silently running with half-parsed options.
GetOptions("spam=s", "nonspam=s", "nocollapse", "nozoom", "buckets=i")
  or die "usage: draw-bayes-histogram [--spam=spam.log] [--nonspam=nonspam.log]\n"
       . "              [--nocollapse] [--nozoom] [--buckets=N]\n";

# Spam log: --spam wins, then the first positional argument (the backwards
# compatible calling form), then the default name.  The defined() test keeps
# -w quiet when no positional arguments were given.
my $spam = $opt_spam;
if (!$spam && defined $ARGV[0] && $ARGV[0] !~ /^\-/) { $spam = $ARGV[0]; }
if (!$spam) { $spam = "spam.log"; }

# Nonspam log: same precedence, using the second positional argument.
my $nonspam = $opt_nonspam;
if (!$nonspam && defined $ARGV[1] && $ARGV[1] !~ /^\-/) { $nonspam = $ARGV[1]; }
if (!$nonspam) { $nonspam = "nonspam.log"; }


# Number of histogram buckets across the score range.  A negative count would
# yield a negative step and a bucket-building loop that never terminates.
my $buckets = $opt_buckets || 25;
die "--buckets must be a positive integer\n" if ($buckets < 1);

my $zoomfactor = 20;      # magnification applied to the "DETAIL" column
my $range_lo = 0.0;       # lowest score shown
my $range_hi = 1.0;       # highest score shown

# Shamelessly nicked from spambayes' testing infrastructure: a system to
# compute the "cost" of a pair of thresholds for a classifier.  I'm supporting
# it here so our stats are (at least a little) comparable against theirs.
#
# Note that they use a different way to run 10PCV tests; they train with 1
# bucket and test against 9, whereas the literature suggests doing the
# opposite, which is what we do; so the stats may still not be directly
# comparable, since we get better training and less testing.  TODO

# Relative weights used by compute_thresholds() when costing a pair of
# cutoffs: one false positive, one false negative, one "unsure" message.
my $best_cutoff_fp_weight     = 10.0;
my $best_cutoff_fn_weight     = 1.0;
my $best_cutoff_unsure_weight = 0.2;

# Per-bucket message counts, keyed by the bucket's lower bound.
%bux_sp = ();
%bux_ns = ();

# Width of a single bucket.
my $step = ($range_hi - $range_lo) / $buckets;

# Walk the score range one step at a time, recording each lower bound in
# @buckets and zeroing both counters for it.  The bounds are produced by
# repeated addition; the same values double as hash keys.
my $i = $range_lo;
while ($i <= $range_hi) {
  push (@buckets, $i);
  $bux_sp{$i} = 0;
  $bux_ns{$i} = 0;
  $i += $step;
}

# Read both logs and count every "#Bayes-Raw-Counts:" line into the bucket
# its recomputed score falls in.  The spam/ham flag travels with the file
# name, so passing the same path for both logs no longer counts everything
# as spam.
foreach my $source ([$spam, 1], [$nonspam, 0]) {
  my ($file, $isspam) = @$source;

  # three-arg open on a lexical handle: the file name is never interpreted
  # as a mode or a pipe, and the handle cannot clash with other code
  open (my $in, '<', $file) or die "Could not open file '$file': $!";

  while (my $line = <$in>) {
    $line =~ /^#Bayes-Raw-Counts: / or next;
    my $score = brc_line_to_score($line);

    # find the bucket whose [lower, lower+step) interval holds the score
    my $bucket_id;
    foreach my $bucket (@buckets) {
      if ($score >= $bucket && $score < $bucket+$step) {
        $bucket_id = $bucket; last;
      }
    }

    # a score outside every bucket must not be counted under an undef key
    if (!defined $bucket_id) {
      warn "score $score is outside the histogram range; line skipped\n";
      next;
    }

    if ($isspam) {
      $bux_sp{$bucket_id}++;
    } else {
      $bux_ns{$bucket_id}++;
    }
  }

  close ($in);
}
draw_hist();
compute_thresholds();

# Cost every (hamcutoff, spamcutoff) pair on the bucket grid and print the
# cheapest ones, together with FP/FN/unsure rates and TCR figures.
#
# Reads the file-level @buckets, %bux_sp, %bux_ns, $step and the
# $best_cutoff_*_weight values; takes no arguments, returns nothing useful.
#
# For a given pair:
#   fn     = spam  in buckets below hamcutoff
#   unsure = mail  in buckets from hamcutoff up to (not including) spamcutoff
#   fp     = ham   in buckets at or above spamcutoff, top bucket included
sub compute_thresholds {
  my $tot_sp = 0;
  my $tot_ns = 0;
  foreach my $bucket (@buckets) {
    $tot_sp += $bux_sp{$bucket};
    $tot_ns += $bux_ns{$bucket};
  }

  # percentage that cannot divide by zero when a corpus is empty
  my $pct = sub {
    my ($part, $whole) = @_;
    return $whole ? ($part * 100) / $whole : 0;
  };

  my @results = ();
  for (my $hamcutoff = 0; $hamcutoff < 1; $hamcutoff += $step) {
    for (my $spamcutoff = 0; $spamcutoff < 1; $spamcutoff += $step) {
      my $fn = 0;
      my $fp = 0;
      my $unsure_sp = 0;
      my $unsure_ns = 0;

      # The three tests are deliberately independent (no elsif): when
      # spamcutoff lies below hamcutoff a bucket counts towards both fn
      # and fp.
      foreach my $bucket (@buckets) {
        if ($bucket < $hamcutoff) {
          $fn += $bux_sp{$bucket};
        }
        # total up the unsures (between hamcutoff and spamcutoff)
        if ($bucket >= $hamcutoff && $bucket < $spamcutoff) {
          $unsure_ns += $bux_ns{$bucket};
          $unsure_sp += $bux_sp{$bucket};
        }
        if ($bucket >= $spamcutoff) {
          $fp += $bux_ns{$bucket};
        }
      }

      my $cost = ($fp * $best_cutoff_fp_weight)
                  + ($fn * $best_cutoff_fn_weight)
                  + (($unsure_ns+$unsure_sp) * $best_cutoff_unsure_weight);

      push (@results, {
        'hamcutoff' => $hamcutoff,
        'spamcutoff'=> $spamcutoff,
        'cost' => $cost,
        'unsure_ns' => $unsure_ns,
        'unsure_sp' => $unsure_sp,
        'fp' => $fp,
        'fn' => $fn
      });
    }
  }

  # Cheapest first; ties are broken on the cutoffs so that the report is
  # the same from run to run.
  my @ranked = sort {
                 $a->{cost} <=> $b->{cost}
                   || $a->{hamcutoff} <=> $b->{hamcutoff}
                   || $a->{spamcutoff} <=> $b->{spamcutoff}
               } @results;

  # NOTE: the counter is tested with a post-decrement against 0 after each
  # entry has been printed, so 12 entries come out, not 10.  Left unchanged
  # to keep the report the same length.
  my $count = 10;
  foreach my $r (@ranked) {
    printf "Threshold optimization for hamcutoff=%3.2f, spamcutoff=%3.2f: cost=\$%5.2f\n",
                  $r->{hamcutoff}, $r->{spamcutoff}, $r->{cost};
    printf "Total ham:spam:   %d:%d\n", $tot_ns, $tot_sp;

    printf "FP: %5d %5.3f%%    ", $r->{fp}, $pct->($r->{fp}, $tot_ns);
    printf "FN: %5d %5.3f%%\n", $r->{fn}, $pct->($r->{fn}, $tot_sp);

    my $unsure = $r->{unsure_ns} + $r->{unsure_sp};
    printf "Unsure: %5d %5.3f%%     ", $unsure,
                  $pct->($unsure, $tot_sp+$tot_ns);
    printf "(ham: %5d %5.3f%%    ", $r->{unsure_ns},
                  $pct->($r->{unsure_ns}, $tot_ns);
    printf "spam: %5d %5.3f%%)\n", $r->{unsure_sp},
                  $pct->($r->{unsure_sp}, $tot_sp);

    # for TCR calc, treat "unsures" as ham
    # TODO: unsure_sp should probably be treated as spam, assuming
    # it'll fall in the 5.0-6.0 score range
    my $fn = $r->{unsure_sp} + $r->{fn};
    my $fp = $r->{fp};

    # with no messages at all tcr()'s denominator would be zero; report 0
    my @tcrs = map {
                 ($tot_sp + $tot_ns)
                     ? tcr ($tot_sp - $fn, $fn, $fp, $tot_ns - $fp, $_)
                     : 0
               } (1, 5, 9);
    printf "TCRs:              l=1 %5.3f    l=5 %5.3f    l=9 %5.3f\n", @tcrs;

    print "\n";
    last if ($count-- < 0);
  }
}

# Print the ASCII histogram of ham ('.') and spam ('#') counts per bucket to
# STDOUT.
#
# Reads the file-level @buckets, %bux_sp, %bux_ns and $zoomfactor and the
# options $opt_nozoom / $opt_nocollapse; takes no arguments.
#
# Each class is scaled separately so that its fullest bucket fills the line.
# Unless --nozoom is given, every bar is preceded by a 10-character "detail"
# column drawn at $zoomfactor times the scale, which makes sparsely
# populated buckets visible.  Empty buckets are left out unless --nocollapse
# is given.
sub draw_hist {
  my $max_sp = 0;
  my $max_ns = 0;
  my $tot_sp = 0;
  my $tot_ns = 0;
  foreach my $bucket (@buckets) {
    $tot_sp += $bux_sp{$bucket};
    if ($bux_sp{$bucket} > $max_sp) { $max_sp = $bux_sp{$bucket}; }
    $tot_ns += $bux_ns{$bucket};
    if ($bux_ns{$bucket} > $max_ns) { $max_ns = $bux_ns{$bucket}; }
  }

  # without the detail column the bar itself gets the 10 extra characters
  my $chars_in_line = 55;
  if ($opt_nozoom) {
    $chars_in_line += 10;
  }

  # messages per output character; the tiny fallbacks avoid / by 0 when a
  # class has no messages
  my $scale_sp = ($max_sp / $chars_in_line) || 0.000001;
  my $scale_ns = ($max_ns / $chars_in_line) || 0.000001;
  $tot_sp ||= 0.000001;
  $tot_ns ||= 0.000001;

  # build one bar: optional fixed-width zoomed prefix, '|', then the bar
  my $bar = sub {
    my ($hits, $scale, $char) = @_;

    my $line = $char x int (($hits / $scale) + .5);
    return $line if ($opt_nozoom);

    my $zoomline = $char x int ((($hits * $zoomfactor) / $scale) + .5);
    $zoomline = sprintf ("%-10s", substr ($zoomline, 0, 10));
    return $zoomline.'|'.$line;
  };

  print STDOUT
   "SCORE  NUMHIT   DETAIL     OVERALL HISTOGRAM  (. = ham, # = spam)\n";
# 0.000 (19.217%) ..........|....................

  foreach my $bucket (@buckets) {
    # --nocollapse means "show empty buckets too"; it must not hide the
    # populated ones
    if ($bux_ns{$bucket} != 0 || $opt_nocollapse) {
      printf STDOUT "%3.3f (%6.3f%%) %s\n", $bucket,
                  (($bux_ns{$bucket} / $tot_ns) * 100.0),
                  $bar->($bux_ns{$bucket}, $scale_ns, '.');
    }
    if ($bux_sp{$bucket} != 0 || $opt_nocollapse) {
      printf STDOUT "%3.3f (%6.3f%%) %s\n", $bucket,
                  (($bux_sp{$bucket} / $tot_sp) * 100.0),
                  $bar->($bux_sp{$bucket}, $scale_sp, '#');
    }
  }
}


# Total cost ratio: the weighted error rate of the do-nothing baseline (all
# spam passes, no ham is lost) divided by the weighted error rate of the
# filter.
#
#   $nspamspam    spam classified as spam
#   $nspamlegit   spam classified as legitimate   (false negatives)
#   $nlegitspam   legitimate classified as spam   (false positives)
#   $nlegitlegit  legitimate classified as legitimate
#   $lambda       weight of one false positive relative to one false negative
#
# Returns the ratio as a number.  A filter that makes no errors is given a
# nominal error rate of 0.000001 instead of 0, and 0 is returned when there
# are no messages at all (previously a division by zero).
sub tcr {
  my ($nspamspam, $nspamlegit, $nlegitspam, $nlegitlegit, $lambda) = @_;

  my $nlegit = $nlegitspam+$nlegitlegit;
  my $nspam = $nspamspam+$nspamlegit;

  # shared denominator of both weighted error rates
  my $weighted_total = $lambda * $nlegit + $nspam;
  return 0 if ($weighted_total == 0);

  my $werr = ($lambda * $nlegitspam + $nspamlegit) / $weighted_total;
  my $werr_base = $nspam / $weighted_total;

  $werr ||= 0.000001;     # avoid / by 0

  return $werr_base / $werr;
}

###########################################################################

# Product of ROBINSON_X and ROBINSON_S_CONSTANT, computed once when the
# constant is declared; used in the f(x) formula in compute_prob_for_token().
# Both factors are presumably supplied by constants.pl -- confirm.
use constant ROBINSON_S_DOT_X => (ROBINSON_X * ROBINSON_S_CONSTANT);

# Memo table used by cached_compute_prob_for_token().
# NOTE(review): this is an ordinary run-time assignment placed below the main
# processing loop, so it appears to execute only after the logs have been
# scored; the table in use during scoring is presumably autovivified on first
# access -- confirm before relying on this initialisation.
$cached_probs = { };

# Recompute the classifier score for one "#Bayes-Raw-Counts:" log line.
#
#   $line   the raw line: a "ns=<n> nn=<n>" header followed by any number of
#           "s=<n>,n=<n>" token records
#
# Each token record is turned into a probability (clamped to
# [PROB_BOUND_LOWER, PROB_BOUND_UPPER]); the strongest ones are then combined
# with the chi-squared or the Robinson combiner, depending on
# USE_CHI_COMBINING.
#
# Returns the combined score, or a neutral 0.5 when the header cannot be
# parsed or no token is strong enough to be used.
sub brc_line_to_score {
  my ($line) = @_;

  # Strip the header and pick up the corpus totals.  Without a header there
  # is nothing meaningful to compute, and $1/$2 would be left over from some
  # earlier match.
  my $has_header =
        ($line =~ s/^\#Bayes-Raw-Counts:\s+ns=(\d+) nn=(\d+)\s+//);
  my ($ns, $nn) = ($1, $2);
  if (!$has_header) { return 0.5; }

  my @probs = ();
  while ($line =~ s/^\s*s=(\d+),n=(\d+)\s*//) {
    my ($s, $n) = ($1, $2);

    my $pw = cached_compute_prob_for_token ($s,$n,$ns,$nn);

    # an undefined probability means "ignore this token"; letting it through
    # would make it compare as 0 and be clamped to PROB_BOUND_LOWER, i.e.
    # count as a very strong ham clue
    next unless defined $pw;

    if ($pw < PROB_BOUND_LOWER) {
      $pw = PROB_BOUND_LOWER;
    } elsif ($pw > PROB_BOUND_UPPER) {
      $pw = PROB_BOUND_UPPER;
    }
    push (@probs, $pw);
  }

  # Take the probabilities furthest from 0.5 first and drop those closer to
  # 0.5 than ROBINSON_MIN_PROB_STRENGTH.
  # NOTE: the post-decrement test admits one candidate more than
  # N_SIGNIFICANT_TOKENS; left as it was so scores do not shift.
  my $count = N_SIGNIFICANT_TOKENS;
  my @sorted = ();

  for (sort {
              abs($b - 0.5) <=> abs($a - 0.5)
            } @probs)
  {
    if ($count-- < 0) { last; }
    next if (abs($_ - 0.5) < ROBINSON_MIN_PROB_STRENGTH);
    push (@sorted, $_);
  }

  if ($#sorted < 0) {
    return 0.5;           # nice and neutral
  }

  my $score;
  if (USE_CHI_COMBINING) {
    $score = chi_squared_probs_combine (@sorted);
  } else {
    $score = robinson_naive_bayes_probs_combine (@sorted);
  }

  return $score;
}

# Memoising front end for compute_prob_for_token().
#
#   $s, $n     number of spam / nonspam messages the token was seen in
#   $ns, $nn   total number of spam / nonspam messages
#
# Returns whatever compute_prob_for_token() returns for these arguments,
# which may be undef.
#
# The result depends on all four arguments, so all four are part of the cache
# key; keying on ($s, $n) alone hands back stale values as soon as two log
# lines carry different corpus totals.  Undefined results are remembered too,
# so they are not recomputed on every call.
sub cached_compute_prob_for_token {
  my ($s, $n, $ns, $nn) = @_;

  # make sure the table exists even if nothing has initialised it yet
  $cached_probs ||= { };

  my $key = join (',', $ns, $nn, $s, $n);
  if (exists $cached_probs->{$key}) {
    return $cached_probs->{$key};
  }

  my $prob = compute_prob_for_token($s,$n,$ns,$nn);
  $cached_probs->{$key} = $prob;
  return $prob;
}

# Probability estimate for a single token.
#
#   $s, $n     number of spam / nonspam messages the token was seen in
#   $ns, $nn   total number of spam / nonspam messages
#
# Returns nothing when USE_ROBINSON_FX_EQUATION_FOR_LOW_FREQS is off and the
# token was seen fewer than 10 times, 0.5 (with a warning) when both
# frequencies are zero, and otherwise the spam frequency divided by the sum
# of both frequencies -- passed through Robinson's f(x) when that option is
# on.
sub compute_prob_for_token {
  my ($s, $n, $ns, $nn) = @_;

  if (!USE_ROBINSON_FX_EQUATION_FOR_LOW_FREQS) {
    # ignore low-freq tokens
    if ($s + $n < 10) { return; }
  }

  my $spam_freq = ($s / $ns);
  my $ham_freq = ($n / $nn);

  if ($spam_freq == 0 && $ham_freq == 0) {
    warn "oops? ratios == ration == 0";
    return 0.5;
  }

  my $prob = ($spam_freq) / ($ham_freq + $spam_freq);

  if (USE_ROBINSON_FX_EQUATION_FOR_LOW_FREQS) {
    # use Robinson's f(x) equation for low-n tokens, instead of just
    # ignoring them
    my $seen = $s+$n;
    $prob = (ROBINSON_S_DOT_X + ($seen * $prob)) /
                  (ROBINSON_S_CONSTANT + $seen);
  }

  return $prob;
}

# Combine token probabilities into a single score.
#
#   @sorted   the token probabilities (must not be empty)
#
# With n probabilities p_i:
#   P = 1 - (product of (1 - p_i)) ** (1/n)
#   Q = 1 - (product of p_i) ** (1/n)
# and the value returned is (1 + (P - Q) / (P + Q)) / 2.
sub robinson_naive_bayes_probs_combine {
  my (@sorted) = @_;

  my $n = scalar @sorted;

  my $prod_inverse = 1;   # running product of (1 - p)
  my $prod_direct = 1;    # running product of p
  for my $p (@sorted) {
    $prod_inverse *= (1-$p);
    $prod_direct *= $p;
  }

  # n-th roots of the two products
  my $P = 1 - ($prod_inverse ** (1 / $n));
  my $Q = 1 - ($prod_direct ** (1 / $n));

  return (1 + ($P - $Q) / ($P + $Q)) / 2.0;
}

###########################################################################

# Chi-squared function: the probability that a chi-squared variable with $v
# degrees of freedom is at least $x2.
#
#   $x2   the chi-squared statistic
#   $v    degrees of freedom; must be even (dies otherwise)
#
# For even $v this has the closed form
#   exp(-m) * sum over i = 0 .. v/2 - 1 of m**i / i!      with m = x2 / 2
# The i = 0 term is the initial value of $sum, so the loop supplies terms
# 1 .. v/2 - 1.  (A bound of $v >> 2 stops at v/4 and drops terms for every
# $v >= 6.)
#
# Returns the probability, capped at 1.0.
sub chi2q {
  my ($x2, $v) = @_;

  die "v must be even in chi2q(x2, v)" if $v & 1;
  my $m = $x2 / 2.0;
  my ($sum, $term);
  $sum = $term = exp(0 - $m);
  for my $i (1 .. (($v / 2) - 1)) {
    $term *= $m / $i;
    $sum += $term;
  }
  return $sum < 1.0 ? $sum : 1.0;
}

# Chi-Squared method. Produces mostly boolean $result,
# but with a grey area.
#
#   @sorted   the token probabilities to combine
#
# Two products are accumulated: $S over (1 - p) and $H over p.  Whenever one
# of them drops below 1e-200 it is split with frexp() into a mantissa, which
# stays in the variable, and a binary exponent, which is added to $Sexp or
# $Hexp, so that long clue lists do not underflow to 0.  The natural log of
# each product is then rebuilt as
#   log(mantissa) + exponent * ln(2)
# and fed through chi2q() with 2 * (number of clues) degrees of freedom.
#
# Returns (S - H + 1) / 2, or 0.5 when @sorted is empty.
sub chi_squared_probs_combine  {
  my (@sorted) = @_;
  # @sorted contains an array of the probabilities

  use POSIX qw(frexp);
  use constant LN2 => log(2);

  my ($H, $S);
  my ($Hexp, $Sexp);
  $H = $S = 1.0;
  $Hexp = $Sexp = 0;

  my $num_clues = @sorted;

  foreach my $prob (@sorted) {
    $S *= 1.0 - $prob;
    $H *= $prob;
    if ($S < 1e-200) {
      my $e;
      ($S, $e) = frexp($S);
      $Sexp += $e;
    }
    if ($H < 1e-200) {
      my $e;
      ($H, $e) = frexp($H);
      $Hexp += $e;
    }
  }

  # the exponents are powers of two, so they are scaled by ln(2); adding
  # ln(2) instead shifts every result, even when no rescaling took place
  $S = log($S) + $Sexp * LN2;
  $H = log($H) + $Hexp * LN2;

  my $result;
  if ($num_clues) {
    $S = 1.0 - chi2q(-2.0 * $S, 2 * $num_clues);
    $H = 1.0 - chi2q(-2.0 * $H, 2 * $num_clues);
    $result = (($S - $H) + 1.0) / 2.0;
  } else {
    $result = 0.5;
  }

  return $result;
}

