Compare commits
10 Commits
bcbfd8b85c
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| | 4ee43157e1 | | |
| | 0a2f9a729a | | |
| | 0afaca5277 | | |
| | 5ffafe09e7 | | |
| | 8217c7b5e4 | | |
| | 57621315f5 | | |
| | bf9c4be620 | | |
| | 30c54bc99b | | |
| | 33c121fa75 | | |
| | 8865e87b9e | | |
11
src/main/java/de/th_luebeck/ws25/CandidateWindow.java
Normal file
11
src/main/java/de/th_luebeck/ws25/CandidateWindow.java
Normal file
@@ -0,0 +1,11 @@
|
||||
package de.th_luebeck.ws25;
|
||||
|
||||
public class CandidateWindow {
|
||||
Interval interval;
|
||||
double score;
|
||||
|
||||
CandidateWindow(Interval interval, double score) {
|
||||
this.interval = interval;
|
||||
this.score = score;
|
||||
}
|
||||
}
|
||||
6
src/main/java/de/th_luebeck/ws25/EvalFile.java
Normal file
6
src/main/java/de/th_luebeck/ws25/EvalFile.java
Normal file
@@ -0,0 +1,6 @@
|
||||
package de.th_luebeck.ws25;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public record EvalFile(List<Integer> sequence, List<Interval> intervals) {
|
||||
}
|
||||
13
src/main/java/de/th_luebeck/ws25/EvaluationResult.java
Normal file
13
src/main/java/de/th_luebeck/ws25/EvaluationResult.java
Normal file
@@ -0,0 +1,13 @@
|
||||
package de.th_luebeck.ws25;
|
||||
|
||||
/**
 * Immutable holder for the three evaluation metrics.
 *
 * <p>The fields are package-private so that classes in the same package can read them
 * directly; they are assigned exactly once in the constructor.
 */
public class EvaluationResult {

    /** Precision value as supplied to the constructor. */
    final double precision;

    /** Recall value as supplied to the constructor. */
    final double recall;

    /** F1 value as supplied to the constructor. */
    final double f1Score;

    /**
     * Stores the given metrics without modification.
     *
     * @param precision precision value
     * @param recall    recall value
     * @param f1Score   F1 value
     */
    EvaluationResult(double precision, double recall, double f1Score) {
        this.precision = precision;
        this.recall = recall;
        this.f1Score = f1Score;
    }
}
|
||||
18
src/main/java/de/th_luebeck/ws25/Interval.java
Normal file
18
src/main/java/de/th_luebeck/ws25/Interval.java
Normal file
@@ -0,0 +1,18 @@
|
||||
package de.th_luebeck.ws25;
|
||||
|
||||
public record Interval(int start, int end) {
|
||||
|
||||
public int length() {
|
||||
return end - start + 1;
|
||||
}
|
||||
|
||||
public boolean overlaps(Interval other) {
|
||||
return this.start <= other.end && other.start <= this.end;
|
||||
}
|
||||
|
||||
public int overlapSize(Interval other) {
|
||||
int a = Math.max(this.start, other.start);
|
||||
int b = Math.min(this.end, other.end);
|
||||
return Math.max(0, b - a + 1);
|
||||
}
|
||||
}
|
||||
6
src/main/java/de/th_luebeck/ws25/LabeledSequence.java
Normal file
6
src/main/java/de/th_luebeck/ws25/LabeledSequence.java
Normal file
@@ -0,0 +1,6 @@
|
||||
package de.th_luebeck.ws25;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public record LabeledSequence(Label label, List<Integer> sequence) {
|
||||
}
|
||||
253
src/main/java/de/th_luebeck/ws25/Main.java
Normal file
253
src/main/java/de/th_luebeck/ws25/Main.java
Normal file
@@ -0,0 +1,253 @@
|
||||
package de.th_luebeck.ws25;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
|
||||
public class Main {
|
||||
|
||||
private static final int NUM_BINS = 14;
|
||||
private static final int WINDOW_LEN = 50;
|
||||
private static final double PRIOR_N = 0.95;
|
||||
private static final double PRIOR_A = 0.05;
|
||||
private static final double LAPLACE = 1.0;
|
||||
|
||||
|
||||
static void main(String[] args) throws IOException {
|
||||
List<LabeledSequence> train = readTrainFile(Path.of(args[0]));
|
||||
if (train.isEmpty()) {
|
||||
System.err.println("YOU FOUND A TURTLE!!! HOWLY PAWS!!!");
|
||||
return;
|
||||
}
|
||||
|
||||
int minValue = Integer.MAX_VALUE;
|
||||
int maxValue = Integer.MIN_VALUE;
|
||||
for (LabeledSequence ls : train) {
|
||||
for (int v : ls.sequence()) {
|
||||
if (v < minValue) minValue = v;
|
||||
if (v > maxValue) maxValue = v;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
minValue = minValue - 1;
|
||||
maxValue = maxValue + 1;
|
||||
|
||||
int binWidth = (maxValue - minValue) / NUM_BINS;
|
||||
|
||||
long[][] countsN = new long[NUM_BINS][NUM_BINS];
|
||||
long[][] countsA = new long[NUM_BINS][NUM_BINS];
|
||||
for (LabeledSequence ls : train) {
|
||||
List<Integer> s = ls.sequence();
|
||||
Integer prev = null;
|
||||
for (int v : s) {
|
||||
int st = valueToState(v, minValue, maxValue, NUM_BINS, binWidth);
|
||||
if (prev != null) {
|
||||
if (ls.label() == Label.ARRHYTHMIA) countsA[prev][st]++;
|
||||
else countsN[prev][st]++;
|
||||
}
|
||||
prev = st;
|
||||
}
|
||||
}
|
||||
double[][] Pn;
|
||||
double[][] Pa;
|
||||
Pn = countsToProb(countsN, LAPLACE);
|
||||
Pa = countsToProb(countsA, LAPLACE);
|
||||
|
||||
EvalFile eval = readEvalFile(Path.of(args[1]));
|
||||
|
||||
List<CandidateWindow> candidates = scoreAllWindows(eval.sequence(), minValue, maxValue, binWidth, Pa, Pn);
|
||||
|
||||
List<Interval> reported = selectNonOverlapping(candidates);
|
||||
|
||||
System.out.println("REPORTED INTERVALS (start..end):");
|
||||
for (Interval it : reported) {
|
||||
System.out.println("[" + it.start() + "," + it.end() + "]");
|
||||
}
|
||||
|
||||
EvaluationResult res = evaluate(reported, eval.intervals());
|
||||
System.out.printf("Precision: %.4f%n", res.precision);
|
||||
System.out.printf("Recall: %.4f%n", res.recall);
|
||||
System.out.printf("F1: %.4f%n", res.f1Score);
|
||||
}
|
||||
|
||||
private static List<LabeledSequence> readTrainFile(Path path) throws IOException {
|
||||
List<LabeledSequence> outputSequences = new ArrayList<>();
|
||||
List<String> lines = Files.readAllLines(path);
|
||||
for (String line : lines) {
|
||||
line = line.trim();
|
||||
if (line.isEmpty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
String[] parts = line.split("\\s+", 2);
|
||||
if (parts.length < 2) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Label label = Label.NORMAL;
|
||||
if (parts[0].equals("A")) {
|
||||
label = Label.ARRHYTHMIA;
|
||||
}
|
||||
|
||||
List<Integer> values = Arrays.stream(parseInts(parts[1])).boxed().toList();
|
||||
outputSequences.add(new LabeledSequence(label, values));
|
||||
}
|
||||
|
||||
return outputSequences;
|
||||
}
|
||||
|
||||
private static EvalFile readEvalFile(Path path) throws IOException {
|
||||
List<String> lines = Files.readAllLines(path);
|
||||
if (lines.size() < 2) {
|
||||
throw new IOException("Eval file must contain at least two lines");
|
||||
}
|
||||
|
||||
List<Integer> sequence = Arrays.stream(parseInts(lines.get(0).trim())).boxed().toList();
|
||||
List<Interval> intervals = parseIntervals(lines.get(1));
|
||||
return new EvalFile(sequence, intervals);
|
||||
}
|
||||
|
||||
|
||||
private static int[] parseInts(String inputString) {
|
||||
List<Integer> numbers = new ArrayList<>();
|
||||
Scanner scanner = new Scanner(inputString);
|
||||
scanner.useDelimiter("[\\s,\\[\\]]+");
|
||||
while (scanner.hasNext()) {
|
||||
if (scanner.hasNextInt()) {
|
||||
numbers.add(scanner.nextInt());
|
||||
} else {
|
||||
scanner.next();
|
||||
}
|
||||
}
|
||||
|
||||
scanner.close();
|
||||
|
||||
return numbers.stream().mapToInt(i -> i).toArray();
|
||||
}
|
||||
|
||||
private static List<Interval> parseIntervals(String inputString) {
|
||||
int[] numbers = parseInts(inputString);
|
||||
List<Interval> outputIntervals = new ArrayList<>();
|
||||
for (int i = 0; i + 1 < numbers.length; i += 2) {
|
||||
int start = numbers[i], end = numbers[i + 1];
|
||||
if (start <= end) {
|
||||
outputIntervals.add(new Interval(start, end));
|
||||
} else {
|
||||
outputIntervals.add(new Interval(end, start));
|
||||
}
|
||||
}
|
||||
|
||||
return outputIntervals;
|
||||
}
|
||||
|
||||
private static EvaluationResult evaluate(List<Interval> reported, List<Interval> truth) {
|
||||
int reportedCount = reported.size();
|
||||
int truthCount = truth.size();
|
||||
boolean[] truthMatched = new boolean[truthCount];
|
||||
int correct = 0;
|
||||
|
||||
for (Interval r : reported) {
|
||||
boolean found = false;
|
||||
for (int i = 0; i < truthCount; i++) {
|
||||
if (truthMatched[i]) continue;
|
||||
Interval t = truth.get(i);
|
||||
if (r.length() != WINDOW_LEN) continue;
|
||||
int ov = r.overlapSize(t);
|
||||
if (ov * 1.0 / WINDOW_LEN >= 0.8) {
|
||||
truthMatched[i] = true;
|
||||
correct++;
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
double precision = reportedCount == 0 ? 0.0 : (double)correct / reportedCount;
|
||||
double recall = truthCount == 0 ? 0.0 : (double)correct / truthCount;
|
||||
double f1 = (precision + recall) == 0 ? 0.0 : 2.0 * precision * recall / (precision + recall);
|
||||
return new EvaluationResult(precision, recall, f1);
|
||||
}
|
||||
|
||||
private static List<Interval> selectNonOverlapping(List<CandidateWindow> candidates) {
|
||||
List<Interval> chosen = new ArrayList<>();
|
||||
candidates.sort((a,b) -> Double.compare(b.score, a.score));
|
||||
boolean[] occupied = new boolean[10000000];
|
||||
for (CandidateWindow cw : candidates) {
|
||||
boolean overlaps = false;
|
||||
for (Interval it : chosen) {
|
||||
if (it.overlaps(cw.interval)) { overlaps = true; break; }
|
||||
}
|
||||
if (!overlaps) chosen.add(cw.interval);
|
||||
}
|
||||
|
||||
chosen.sort(Comparator.comparingInt(Interval::start));
|
||||
return chosen;
|
||||
}
|
||||
|
||||
private static List<CandidateWindow> scoreAllWindows(List<Integer> sequence, int minValue, int maxValue, int binWidth, double[][] Pa, double[][] Pn) {
|
||||
List<CandidateWindow> res = new ArrayList<>();
|
||||
int L = sequence.size();
|
||||
if (L < WINDOW_LEN) return res;
|
||||
double logPriorA = Math.log(PRIOR_A);
|
||||
double logPriorN = Math.log(PRIOR_N);
|
||||
|
||||
for (int start = 0; start <= L - WINDOW_LEN; start++) {
|
||||
int end = start + WINDOW_LEN - 1;
|
||||
double logLikeA = 0.0;
|
||||
double logLikeN = 0.0;
|
||||
boolean ok = true;
|
||||
int prevState = valueToState(sequence.get(start), minValue, maxValue, NUM_BINS, binWidth);
|
||||
for (int t = start + 1; t <= end; t++) {
|
||||
int curState = valueToState(sequence.get(t), minValue, maxValue, NUM_BINS, binWidth);
|
||||
double pa = Pa[prevState][curState];
|
||||
double pn = Pn[prevState][curState];
|
||||
if (pa <= 0) pa = LAPLACE;
|
||||
if (pn <= 0) pn = LAPLACE;
|
||||
logLikeA += Math.log(pa);
|
||||
logLikeN += Math.log(pn);
|
||||
prevState = curState;
|
||||
}
|
||||
double logPostA = logLikeA + logPriorA;
|
||||
double logPostN = logLikeN + logPriorN;
|
||||
double score = logPostA - logPostN;
|
||||
|
||||
double pA = Math.exp(logPostA) / (Math.exp(logPostA) + Math.exp(logPostN));
|
||||
if (pA > 0.8) {
|
||||
res.add(new CandidateWindow(new Interval(start, end), score));
|
||||
}
|
||||
|
||||
/*
|
||||
if (score > 0) {
|
||||
res.add(new CandidateWindow(new Interval(start, end), score));
|
||||
}
|
||||
*/
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
private static double[][] countsToProb(long[][] counts, double laplace) {
|
||||
int n = counts.length;
|
||||
double[][] P = new double[n][n];
|
||||
for (int i = 0; i < n; i++) {
|
||||
double sum = 0.0;
|
||||
for (int j = 0; j < n; j++) sum += counts[i][j] + laplace;
|
||||
if (sum <= 0) {
|
||||
double u = 1.0 / n;
|
||||
for (int j = 0; j < n; j++) P[i][j] = u;
|
||||
} else {
|
||||
for (int j = 0; j < n; j++) P[i][j] = (counts[i][j] + laplace) / sum;
|
||||
}
|
||||
}
|
||||
return P;
|
||||
}
|
||||
|
||||
private static int valueToState(int val, int minVal, int maxVal, int numBins, int binWidth) {
|
||||
if (val <= minVal) return 0;
|
||||
if (val >= maxVal) return numBins - 1;
|
||||
int idx = (int)((val - minVal) / binWidth);
|
||||
if (idx < 0) idx = 0;
|
||||
if (idx >= numBins) idx = numBins - 1;
|
||||
return idx;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user