Files
AareLC_train/tests/test_generate_splits.py
2026-04-14 16:07:31 +02:00

66 lines
2.6 KiB
Python

"""Unit tests for src/tools/generate_splits.py"""
import pytest
from pathlib import Path
from src.tools.generate_splits import generate_splits_txt, SPLITS
@pytest.fixture
def splits_dir(tmp_path):
    """Build a throwaway dataset root containing ``train`` and ``val`` label dirs.

    Each of the two splits gets a ``labels`` sub-directory holding three
    files named ``img_<split>_<NNN>.txt``, each containing one placeholder
    annotation line. The ``test`` split is deliberately not created so the
    tests can exercise the missing-split path.

    Returns:
        The ``tmp_path`` directory, now populated as described above.
    """
    label_line = "0 0.5 0.5 0.1 0.1\n"
    for split_name in ("train", "val"):
        labels = tmp_path / split_name / "labels"
        labels.mkdir(parents=True)
        for idx in range(3):
            target = labels / f"img_{split_name}_{idx:03d}.txt"
            target.write_text(label_line)
    # "test" is intentionally left out.
    return tmp_path
class TestGenerateSplitsTxt:
    """Behavioural tests for ``generate_splits_txt``.

    Tests write the split list under ``tmp_path``, which is also the dataset
    root returned by the ``splits_dir`` fixture.
    """

    @staticmethod
    def _read_entries(path):
        """Return the non-empty lines of the generated split file at *path*."""
        return [line for line in path.read_text().splitlines() if line]

    def test_creates_output_file(self, splits_dir, tmp_path):
        """The output file is created at the requested location."""
        out = tmp_path / "output" / "split.txt"
        out.parent.mkdir()
        generate_splits_txt(splits_dir, out)
        assert out.exists()

    def test_entries_have_correct_format(self, splits_dir, tmp_path):
        """Every entry has the form ``<split>/labels/<name>.txt``."""
        out = tmp_path / "split.txt"
        generate_splits_txt(splits_dir, out)
        lines = self._read_entries(out)
        # Guard against a vacuous pass: the loop below checks nothing if the
        # generator wrote no entries at all.
        assert lines, "expected at least one entry"
        for line in lines:
            parts = line.split("/")
            # Check the shape first so a malformed line fails with a clear
            # message instead of an IndexError.
            assert len(parts) == 3, f"expected <split>/labels/<file>, got: {line}"
            split_name, subdir, filename = parts
            assert split_name in SPLITS, f"unexpected split name in: {line}"
            assert subdir == "labels"
            assert filename.endswith(".txt")

    def test_entry_count_matches_label_files(self, splits_dir, tmp_path):
        """One entry is written per label file that exists."""
        out = tmp_path / "split.txt"
        generate_splits_txt(splits_dir, out)
        # 3 train + 3 val = 6; the fixture does not create "test".
        assert len(self._read_entries(out)) == 6

    def test_train_entries_come_before_val(self, splits_dir, tmp_path):
        """All ``train`` entries precede all ``val`` entries."""
        out = tmp_path / "split.txt"
        generate_splits_txt(splits_dir, out)
        splits_seen = [line.split("/")[0] for line in self._read_entries(out)]
        train_indices = [i for i, s in enumerate(splits_seen) if s == "train"]
        val_indices = [i for i, s in enumerate(splits_seen) if s == "val"]
        # max()/min() raise ValueError on empty input; assert presence first
        # so a missing split is reported as a test failure, not an error.
        assert train_indices, "no train entries written"
        assert val_indices, "no val entries written"
        assert max(train_indices) < min(val_indices)

    def test_missing_split_dir_is_skipped_gracefully(self, tmp_path):
        """A root with no split sub-directories yields an empty split file."""
        # Empty directory: none of the split sub-directories exist.
        out = tmp_path / "split.txt"
        generate_splits_txt(tmp_path, out)
        assert self._read_entries(out) == []

    def test_entries_are_sorted_within_split(self, splits_dir, tmp_path):
        """Entries of each populated split appear in sorted order."""
        out = tmp_path / "split.txt"
        generate_splits_txt(splits_dir, out)
        lines = self._read_entries(out)
        # Check every split the fixture populates, not only "train".
        for split_name in ("train", "val"):
            prefix = f"{split_name}/"
            split_lines = [line for line in lines if line.startswith(prefix)]
            assert split_lines, f"no {split_name} entries written"
            assert split_lines == sorted(split_lines), (
                f"{split_name} entries are not sorted"
            )