diff --git a/src/main/java/de/th_luebeck/ws25/Main.java b/src/main/java/de/th_luebeck/ws25/Main.java index 503720e..3732966 100644 --- a/src/main/java/de/th_luebeck/ws25/Main.java +++ b/src/main/java/de/th_luebeck/ws25/Main.java @@ -6,11 +6,74 @@ import java.nio.file.Path; import java.util.*; public class Main { - static void main(String[] args) { - /* ... */ + + private static final int NUM_BINS = 9; + private static final int WINDOW_LEN = 50; + private static final double PRIOR_N = 0.95; + private static final double PRIOR_A = 0.05; + private static final double LAPLACE = 1e-6; + + + static void main(String[] args) throws IOException { + List train = readTrainFile(Path.of(args[0])); + if (train.isEmpty()) { + System.err.println("YOU FOUND A TURTLE!!! HOWLY PAWS!!!"); + return; + } + + int minValue = Integer.MAX_VALUE; + int maxValue = Integer.MIN_VALUE; + for (LabeledSequence sequence : train) { + for (LabeledSequence labeledSequence : train) { + for (int v : labeledSequence.sequence()) { + if (v < minValue) minValue = v; + if (v > maxValue) maxValue = v; + } + } + } + + minValue = minValue - 1; + maxValue = maxValue + 1; + + int binWidth = (maxValue - minValue) / NUM_BINS; + + long[][] countsN = new long[NUM_BINS][NUM_BINS]; + long[][] countsA = new long[NUM_BINS][NUM_BINS]; + for (LabeledSequence ls : train) { + List s = ls.sequence(); + Integer prev = null; + for (int v : s) { + int st = valueToState(v, minValue, maxValue, NUM_BINS, binWidth); + if (prev != null) { + if (ls.label() == Label.ARRHYTHMIA) countsA[prev][st]++; + else countsN[prev][st]++; + } + prev = st; + } + } + double[][] Pn; + double[][] Pa; + Pn = countsToProb(countsN, LAPLACE); + Pa = countsToProb(countsA, LAPLACE); + + EvalFile eval = readEvalFile(Path.of(args[1])); + + List candidates = scoreAllWindows(eval.sequence(), minValue, maxValue, binWidth, Pa, Pn); + + List reported = selectNonOverlapping(candidates); + + System.out.println("REPORTED INTERVALS (start..end):"); + for (Interval it : reported) { + System.out.println("[" + it.start() + "," + it.end() + "]"); + } + + EvaluationResult res = evaluate(reported, eval.intervals()); + System.out.printf("Precision: %.4f%n", res.precision); + System.out.printf("Recall: %.4f%n", res.recall); + System.out.printf("F1: %.4f%n", res.f1Score); } - private List readTrainFile(Path path) throws IOException { + private static List readTrainFile(Path path) throws IOException { List outputSequences = new ArrayList<>(); List lines = Files.readAllLines(path); for (String line : lines) { @@ -36,7 +99,7 @@ public class Main { return outputSequences; } - public static EvalFile readEvalFile(Path path) throws IOException { + private static EvalFile readEvalFile(Path path) throws IOException { List lines = Files.readAllLines(path); if (lines.size() < 2) { throw new IOException("Eval file must contain at least two lines"); @@ -48,7 +111,7 @@ public class Main { } - public static int[] parseInts(String inputString) { + private static int[] parseInts(String inputString) { List numbers = new ArrayList<>(); Scanner scanner = new Scanner(inputString); scanner.useDelimiter("[\\s,\\[\\]]+"); @@ -65,7 +128,7 @@ public class Main { return numbers.stream().mapToInt(i -> i).toArray(); } - public static List parseIntervals(String inputString) { + private static List parseIntervals(String inputString) { int[] numbers = parseInts(inputString); List outputIntervals = new ArrayList<>(); for (int i = 0; i + 1 < numbers.length; i += 2) { @@ -79,4 +142,113 @@ public class Main { return outputIntervals; } + + private static EvaluationResult evaluate(List reported, List truth) { + int reportedCount = reported.size(); + int truthCount = truth.size(); + boolean[] truthMatched = new boolean[truthCount]; + int correct = 0; + + for (Interval r : reported) { + boolean found = false; + for (int i = 0; i < truthCount; i++) { + if (truthMatched[i]) continue; + Interval t = truth.get(i); + if (r.length() != WINDOW_LEN) continue; + int ov = r.overlapSize(t); + if (ov * 1.0 / WINDOW_LEN >= 0.8) { + truthMatched[i] = true; + correct++; + found = true; + break; + } + } + } + double precision = reportedCount == 0 ? 0.0 : (double)correct / reportedCount; + double recall = truthCount == 0 ? 0.0 : (double)correct / truthCount; + double f1 = (precision + recall) == 0 ? 0.0 : 2.0 * precision * recall / (precision + recall); + return new EvaluationResult(precision, recall, f1); + } + + private static List selectNonOverlapping(List candidates) { + List chosen = new ArrayList<>(); + candidates.sort((a,b) -> Double.compare(b.score, a.score)); + boolean[] occupied = new boolean[10000000]; + for (CandidateWindow cw : candidates) { + boolean overlaps = false; + for (Interval it : chosen) { + if (it.overlaps(cw.interval)) { overlaps = true; break; } + } + if (!overlaps) chosen.add(cw.interval); + } + + chosen.sort(Comparator.comparingInt(Interval::start)); + return chosen; + } + + private static List scoreAllWindows(List sequence, int minValue, int maxValue, int binWidth, double[][] Pa, double[][] Pn) { + List res = new ArrayList<>(); + int L = sequence.size(); + if (L < WINDOW_LEN) return res; + double logPriorA = Math.log(PRIOR_A); + double logPriorN = Math.log(PRIOR_N); + + for (int start = 0; start <= L - WINDOW_LEN; start++) { + int end = start + WINDOW_LEN - 1; + double logLikeA = 0.0; + double logLikeN = 0.0; + boolean ok = true; + int prevState = valueToState(sequence.get(start), minValue, maxValue, NUM_BINS, binWidth); + for (int t = start + 1; t <= end; t++) { + int curState = valueToState(sequence.get(t), minValue, maxValue, NUM_BINS, binWidth); + double pa = Pa[prevState][curState]; + double pn = Pn[prevState][curState]; + if (pa <= 0) pa = LAPLACE; + if (pn <= 0) pn = LAPLACE; + logLikeA += Math.log(pa); + logLikeN += Math.log(pn); + prevState = curState; + } + double logPostA = logLikeA + logPriorA; + double logPostN = logLikeN + logPriorN; + double score = logPostA - logPostN; + + double pA = Math.exp(logPostA) / (Math.exp(logPostA) + Math.exp(logPostN)); + if (pA > 0.8) { + res.add(new CandidateWindow(new Interval(start, end), score)); + } + + /* + if (score > 0) { + res.add(new CandidateWindow(new Interval(start, end), score)); + } + */ + } + return res; + } + + private static double[][] countsToProb(long[][] counts, double laplace) { + int n = counts.length; + double[][] P = new double[n][n]; + for (int i = 0; i < n; i++) { + double sum = 0.0; + for (int j = 0; j < n; j++) sum += counts[i][j] + laplace; + if (sum <= 0) { + double u = 1.0 / n; + for (int j = 0; j < n; j++) P[i][j] = u; + } else { + for (int j = 0; j < n; j++) P[i][j] = (counts[i][j] + laplace) / sum; + } + } + return P; + } + + private static int valueToState(int val, int minVal, int maxVal, int numBins, int binWidth) { + if (val <= minVal) return 0; + if (val >= maxVal) return numBins - 1; + int idx = (int)((val - minVal) / binWidth); + if (idx < 0) idx = 0; + if (idx >= numBins) idx = numBins - 1; + return idx; + } }