feat: enhance Main class with evaluation and scoring methods for intervals
This commit is contained in:
@@ -6,11 +6,74 @@ import java.nio.file.Path;
|
|||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
public class Main {
|
public class Main {
|
||||||
static void main(String[] args) {
|
|
||||||
/* ... */
|
private static final int NUM_BINS = 9;
|
||||||
|
private static final int WINDOW_LEN = 50;
|
||||||
|
private static final double PRIOR_N = 0.95;
|
||||||
|
private static final double PRIOR_A = 0.05;
|
||||||
|
private static final double LAPLACE = 1e-6;
|
||||||
|
|
||||||
|
|
||||||
|
static void main(String[] args) throws IOException {
|
||||||
|
List<LabeledSequence> train = readTrainFile(Path.of(args[0]));
|
||||||
|
if (train.isEmpty()) {
|
||||||
|
System.err.println("YOU FOUND A TURTLE!!! HOWLY PAWS!!!");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int minValue = Integer.MAX_VALUE;
|
||||||
|
int maxValue = Integer.MIN_VALUE;
|
||||||
|
for (LabeledSequence sequence : train) {
|
||||||
|
for (LabeledSequence labeledSequence : train) {
|
||||||
|
for (int v : labeledSequence.sequence()) {
|
||||||
|
if (v < minValue) minValue = v;
|
||||||
|
if (v > maxValue) maxValue = v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
minValue = minValue - 1;
|
||||||
|
maxValue = maxValue + 1;
|
||||||
|
|
||||||
|
int binWidth = (maxValue - minValue) / NUM_BINS;
|
||||||
|
|
||||||
|
long[][] countsN = new long[NUM_BINS][NUM_BINS];
|
||||||
|
long[][] countsA = new long[NUM_BINS][NUM_BINS];
|
||||||
|
for (LabeledSequence ls : train) {
|
||||||
|
List<Integer> s = ls.sequence();
|
||||||
|
Integer prev = null;
|
||||||
|
for (int v : s) {
|
||||||
|
int st = valueToState(v, minValue, maxValue, NUM_BINS, binWidth);
|
||||||
|
if (prev != null) {
|
||||||
|
if (ls.label() == Label.ARRHYTHMIA) countsA[prev][st]++;
|
||||||
|
else countsN[prev][st]++;
|
||||||
|
}
|
||||||
|
prev = st;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
double[][] Pn;
|
||||||
|
double[][] Pa;
|
||||||
|
Pn = countsToProb(countsN, LAPLACE);
|
||||||
|
Pa = countsToProb(countsA, LAPLACE);
|
||||||
|
|
||||||
|
EvalFile eval = readEvalFile(Path.of(args[1]));
|
||||||
|
|
||||||
|
List<CandidateWindow> candidates = scoreAllWindows(eval.sequence(), minValue, maxValue, binWidth, Pa, Pn);
|
||||||
|
|
||||||
|
List<Interval> reported = selectNonOverlapping(candidates);
|
||||||
|
|
||||||
|
System.out.println("REPORTED INTERVALS (start..end):");
|
||||||
|
for (Interval it : reported) {
|
||||||
|
System.out.println("[" + it.start() + "," + it.end() + "]");
|
||||||
|
}
|
||||||
|
|
||||||
|
EvaluationResult res = evaluate(reported, eval.intervals());
|
||||||
|
System.out.printf("Precision: %.4f%n", res.precision);
|
||||||
|
System.out.printf("Recall: %.4f%n", res.recall);
|
||||||
|
System.out.printf("F1: %.4f%n", res.f1Score);
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<LabeledSequence> readTrainFile(Path path) throws IOException {
|
private static List<LabeledSequence> readTrainFile(Path path) throws IOException {
|
||||||
List<LabeledSequence> outputSequences = new ArrayList<>();
|
List<LabeledSequence> outputSequences = new ArrayList<>();
|
||||||
List<String> lines = Files.readAllLines(path);
|
List<String> lines = Files.readAllLines(path);
|
||||||
for (String line : lines) {
|
for (String line : lines) {
|
||||||
@@ -36,7 +99,7 @@ public class Main {
|
|||||||
return outputSequences;
|
return outputSequences;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static EvalFile readEvalFile(Path path) throws IOException {
|
private static EvalFile readEvalFile(Path path) throws IOException {
|
||||||
List<String> lines = Files.readAllLines(path);
|
List<String> lines = Files.readAllLines(path);
|
||||||
if (lines.size() < 2) {
|
if (lines.size() < 2) {
|
||||||
throw new IOException("Eval file must contain at least two lines");
|
throw new IOException("Eval file must contain at least two lines");
|
||||||
@@ -48,7 +111,7 @@ public class Main {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public static int[] parseInts(String inputString) {
|
private static int[] parseInts(String inputString) {
|
||||||
List<Integer> numbers = new ArrayList<>();
|
List<Integer> numbers = new ArrayList<>();
|
||||||
Scanner scanner = new Scanner(inputString);
|
Scanner scanner = new Scanner(inputString);
|
||||||
scanner.useDelimiter("[\\s,\\[\\]]+");
|
scanner.useDelimiter("[\\s,\\[\\]]+");
|
||||||
@@ -65,7 +128,7 @@ public class Main {
|
|||||||
return numbers.stream().mapToInt(i -> i).toArray();
|
return numbers.stream().mapToInt(i -> i).toArray();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<Interval> parseIntervals(String inputString) {
|
private static List<Interval> parseIntervals(String inputString) {
|
||||||
int[] numbers = parseInts(inputString);
|
int[] numbers = parseInts(inputString);
|
||||||
List<Interval> outputIntervals = new ArrayList<>();
|
List<Interval> outputIntervals = new ArrayList<>();
|
||||||
for (int i = 0; i + 1 < numbers.length; i += 2) {
|
for (int i = 0; i + 1 < numbers.length; i += 2) {
|
||||||
@@ -79,4 +142,113 @@ public class Main {
|
|||||||
|
|
||||||
return outputIntervals;
|
return outputIntervals;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static EvaluationResult evaluate(List<Interval> reported, List<Interval> truth) {
|
||||||
|
int reportedCount = reported.size();
|
||||||
|
int truthCount = truth.size();
|
||||||
|
boolean[] truthMatched = new boolean[truthCount];
|
||||||
|
int correct = 0;
|
||||||
|
|
||||||
|
for (Interval r : reported) {
|
||||||
|
boolean found = false;
|
||||||
|
for (int i = 0; i < truthCount; i++) {
|
||||||
|
if (truthMatched[i]) continue;
|
||||||
|
Interval t = truth.get(i);
|
||||||
|
if (r.length() != WINDOW_LEN) continue;
|
||||||
|
int ov = r.overlapSize(t);
|
||||||
|
if (ov * 1.0 / WINDOW_LEN >= 0.8) {
|
||||||
|
truthMatched[i] = true;
|
||||||
|
correct++;
|
||||||
|
found = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
double precision = reportedCount == 0 ? 0.0 : (double)correct / reportedCount;
|
||||||
|
double recall = truthCount == 0 ? 0.0 : (double)correct / truthCount;
|
||||||
|
double f1 = (precision + recall) == 0 ? 0.0 : 2.0 * precision * recall / (precision + recall);
|
||||||
|
return new EvaluationResult(precision, recall, f1);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Interval> selectNonOverlapping(List<CandidateWindow> candidates) {
|
||||||
|
List<Interval> chosen = new ArrayList<>();
|
||||||
|
candidates.sort((a,b) -> Double.compare(b.score, a.score));
|
||||||
|
boolean[] occupied = new boolean[10000000];
|
||||||
|
for (CandidateWindow cw : candidates) {
|
||||||
|
boolean overlaps = false;
|
||||||
|
for (Interval it : chosen) {
|
||||||
|
if (it.overlaps(cw.interval)) { overlaps = true; break; }
|
||||||
|
}
|
||||||
|
if (!overlaps) chosen.add(cw.interval);
|
||||||
|
}
|
||||||
|
|
||||||
|
chosen.sort(Comparator.comparingInt(Interval::start));
|
||||||
|
return chosen;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<CandidateWindow> scoreAllWindows(List<Integer> sequence, int minValue, int maxValue, int binWidth, double[][] Pa, double[][] Pn) {
|
||||||
|
List<CandidateWindow> res = new ArrayList<>();
|
||||||
|
int L = sequence.size();
|
||||||
|
if (L < WINDOW_LEN) return res;
|
||||||
|
double logPriorA = Math.log(PRIOR_A);
|
||||||
|
double logPriorN = Math.log(PRIOR_N);
|
||||||
|
|
||||||
|
for (int start = 0; start <= L - WINDOW_LEN; start++) {
|
||||||
|
int end = start + WINDOW_LEN - 1;
|
||||||
|
double logLikeA = 0.0;
|
||||||
|
double logLikeN = 0.0;
|
||||||
|
boolean ok = true;
|
||||||
|
int prevState = valueToState(sequence.get(start), minValue, maxValue, NUM_BINS, binWidth);
|
||||||
|
for (int t = start + 1; t <= end; t++) {
|
||||||
|
int curState = valueToState(sequence.get(t), minValue, maxValue, NUM_BINS, binWidth);
|
||||||
|
double pa = Pa[prevState][curState];
|
||||||
|
double pn = Pn[prevState][curState];
|
||||||
|
if (pa <= 0) pa = LAPLACE;
|
||||||
|
if (pn <= 0) pn = LAPLACE;
|
||||||
|
logLikeA += Math.log(pa);
|
||||||
|
logLikeN += Math.log(pn);
|
||||||
|
prevState = curState;
|
||||||
|
}
|
||||||
|
double logPostA = logLikeA + logPriorA;
|
||||||
|
double logPostN = logLikeN + logPriorN;
|
||||||
|
double score = logPostA - logPostN;
|
||||||
|
|
||||||
|
double pA = Math.exp(logPostA) / (Math.exp(logPostA) + Math.exp(logPostN));
|
||||||
|
if (pA > 0.8) {
|
||||||
|
res.add(new CandidateWindow(new Interval(start, end), score));
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
if (score > 0) {
|
||||||
|
res.add(new CandidateWindow(new Interval(start, end), score));
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static double[][] countsToProb(long[][] counts, double laplace) {
|
||||||
|
int n = counts.length;
|
||||||
|
double[][] P = new double[n][n];
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
double sum = 0.0;
|
||||||
|
for (int j = 0; j < n; j++) sum += counts[i][j] + laplace;
|
||||||
|
if (sum <= 0) {
|
||||||
|
double u = 1.0 / n;
|
||||||
|
for (int j = 0; j < n; j++) P[i][j] = u;
|
||||||
|
} else {
|
||||||
|
for (int j = 0; j < n; j++) P[i][j] = (counts[i][j] + laplace) / sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return P;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static int valueToState(int val, int minVal, int maxVal, int numBins, int binWidth) {
|
||||||
|
if (val <= minVal) return 0;
|
||||||
|
if (val >= maxVal) return numBins - 1;
|
||||||
|
int idx = (int)((val - minVal) / binWidth);
|
||||||
|
if (idx < 0) idx = 0;
|
||||||
|
if (idx >= numBins) idx = numBins - 1;
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user