feat: enhance Main class with evaluation and scoring methods for intervals

This commit is contained in:
2025-12-05 23:32:16 +01:00
parent 5ffafe09e7
commit 0afaca5277

View File

@@ -6,11 +6,74 @@ import java.nio.file.Path;
import java.util.*;
public class Main {
static void main(String[] args) {
/* ... */
private static final int NUM_BINS = 9;
private static final int WINDOW_LEN = 50;
private static final double PRIOR_N = 0.95;
private static final double PRIOR_A = 0.05;
private static final double LAPLACE = 1e-6;
static void main(String[] args) throws IOException {
List<LabeledSequence> train = readTrainFile(Path.of(args[0]));
if (train.isEmpty()) {
System.err.println("YOU FOUND A TURTLE!!! HOWLY PAWS!!!");
return;
}
int minValue = Integer.MAX_VALUE;
int maxValue = Integer.MIN_VALUE;
for (LabeledSequence sequence : train) {
for (LabeledSequence labeledSequence : train) {
for (int v : labeledSequence.sequence()) {
if (v < minValue) minValue = v;
if (v > maxValue) maxValue = v;
}
}
}
minValue = minValue - 1;
maxValue = maxValue + 1;
int binWidth = (maxValue - minValue) / NUM_BINS;
long[][] countsN = new long[NUM_BINS][NUM_BINS];
long[][] countsA = new long[NUM_BINS][NUM_BINS];
for (LabeledSequence ls : train) {
List<Integer> s = ls.sequence();
Integer prev = null;
for (int v : s) {
int st = valueToState(v, minValue, maxValue, NUM_BINS, binWidth);
if (prev != null) {
if (ls.label() == Label.ARRHYTHMIA) countsA[prev][st]++;
else countsN[prev][st]++;
}
prev = st;
}
}
double[][] Pn;
double[][] Pa;
Pn = countsToProb(countsN, LAPLACE);
Pa = countsToProb(countsA, LAPLACE);
EvalFile eval = readEvalFile(Path.of(args[1]));
List<CandidateWindow> candidates = scoreAllWindows(eval.sequence(), minValue, maxValue, binWidth, Pa, Pn);
List<Interval> reported = selectNonOverlapping(candidates);
System.out.println("REPORTED INTERVALS (start..end):");
for (Interval it : reported) {
System.out.println("[" + it.start() + "," + it.end() + "]");
}
EvaluationResult res = evaluate(reported, eval.intervals());
System.out.printf("Precision: %.4f%n", res.precision);
System.out.printf("Recall: %.4f%n", res.recall);
System.out.printf("F1: %.4f%n", res.f1Score);
}
private List<LabeledSequence> readTrainFile(Path path) throws IOException {
private static List<LabeledSequence> readTrainFile(Path path) throws IOException {
List<LabeledSequence> outputSequences = new ArrayList<>();
List<String> lines = Files.readAllLines(path);
for (String line : lines) {
@@ -36,7 +99,7 @@ public class Main {
return outputSequences;
}
public static EvalFile readEvalFile(Path path) throws IOException {
private static EvalFile readEvalFile(Path path) throws IOException {
List<String> lines = Files.readAllLines(path);
if (lines.size() < 2) {
throw new IOException("Eval file must contain at least two lines");
@@ -48,7 +111,7 @@ public class Main {
}
public static int[] parseInts(String inputString) {
private static int[] parseInts(String inputString) {
List<Integer> numbers = new ArrayList<>();
Scanner scanner = new Scanner(inputString);
scanner.useDelimiter("[\\s,\\[\\]]+");
@@ -65,7 +128,7 @@ public class Main {
return numbers.stream().mapToInt(i -> i).toArray();
}
public static List<Interval> parseIntervals(String inputString) {
private static List<Interval> parseIntervals(String inputString) {
int[] numbers = parseInts(inputString);
List<Interval> outputIntervals = new ArrayList<>();
for (int i = 0; i + 1 < numbers.length; i += 2) {
@@ -79,4 +142,113 @@ public class Main {
return outputIntervals;
}
private static EvaluationResult evaluate(List<Interval> reported, List<Interval> truth) {
int reportedCount = reported.size();
int truthCount = truth.size();
boolean[] truthMatched = new boolean[truthCount];
int correct = 0;
for (Interval r : reported) {
boolean found = false;
for (int i = 0; i < truthCount; i++) {
if (truthMatched[i]) continue;
Interval t = truth.get(i);
if (r.length() != WINDOW_LEN) continue;
int ov = r.overlapSize(t);
if (ov * 1.0 / WINDOW_LEN >= 0.8) {
truthMatched[i] = true;
correct++;
found = true;
break;
}
}
}
double precision = reportedCount == 0 ? 0.0 : (double)correct / reportedCount;
double recall = truthCount == 0 ? 0.0 : (double)correct / truthCount;
double f1 = (precision + recall) == 0 ? 0.0 : 2.0 * precision * recall / (precision + recall);
return new EvaluationResult(precision, recall, f1);
}
private static List<Interval> selectNonOverlapping(List<CandidateWindow> candidates) {
List<Interval> chosen = new ArrayList<>();
candidates.sort((a,b) -> Double.compare(b.score, a.score));
boolean[] occupied = new boolean[10000000];
for (CandidateWindow cw : candidates) {
boolean overlaps = false;
for (Interval it : chosen) {
if (it.overlaps(cw.interval)) { overlaps = true; break; }
}
if (!overlaps) chosen.add(cw.interval);
}
chosen.sort(Comparator.comparingInt(Interval::start));
return chosen;
}
private static List<CandidateWindow> scoreAllWindows(List<Integer> sequence, int minValue, int maxValue, int binWidth, double[][] Pa, double[][] Pn) {
List<CandidateWindow> res = new ArrayList<>();
int L = sequence.size();
if (L < WINDOW_LEN) return res;
double logPriorA = Math.log(PRIOR_A);
double logPriorN = Math.log(PRIOR_N);
for (int start = 0; start <= L - WINDOW_LEN; start++) {
int end = start + WINDOW_LEN - 1;
double logLikeA = 0.0;
double logLikeN = 0.0;
boolean ok = true;
int prevState = valueToState(sequence.get(start), minValue, maxValue, NUM_BINS, binWidth);
for (int t = start + 1; t <= end; t++) {
int curState = valueToState(sequence.get(t), minValue, maxValue, NUM_BINS, binWidth);
double pa = Pa[prevState][curState];
double pn = Pn[prevState][curState];
if (pa <= 0) pa = LAPLACE;
if (pn <= 0) pn = LAPLACE;
logLikeA += Math.log(pa);
logLikeN += Math.log(pn);
prevState = curState;
}
double logPostA = logLikeA + logPriorA;
double logPostN = logLikeN + logPriorN;
double score = logPostA - logPostN;
double pA = Math.exp(logPostA) / (Math.exp(logPostA) + Math.exp(logPostN));
if (pA > 0.8) {
res.add(new CandidateWindow(new Interval(start, end), score));
}
/*
if (score > 0) {
res.add(new CandidateWindow(new Interval(start, end), score));
}
*/
}
return res;
}
private static double[][] countsToProb(long[][] counts, double laplace) {
int n = counts.length;
double[][] P = new double[n][n];
for (int i = 0; i < n; i++) {
double sum = 0.0;
for (int j = 0; j < n; j++) sum += counts[i][j] + laplace;
if (sum <= 0) {
double u = 1.0 / n;
for (int j = 0; j < n; j++) P[i][j] = u;
} else {
for (int j = 0; j < n; j++) P[i][j] = (counts[i][j] + laplace) / sum;
}
}
return P;
}
private static int valueToState(int val, int minVal, int maxVal, int numBins, int binWidth) {
if (val <= minVal) return 0;
if (val >= maxVal) return numBins - 1;
int idx = (int)((val - minVal) / binWidth);
if (idx < 0) idx = 0;
if (idx >= numBins) idx = numBins - 1;
return idx;
}
}