66 lines
2.6 KiB
Python
66 lines
2.6 KiB
Python
"""Unit tests for src/tools/generate_splits.py"""
|
|
import pytest
|
|
from pathlib import Path
|
|
|
|
from src.tools.generate_splits import generate_splits_txt, SPLITS
|
|
|
|
|
|
@pytest.fixture
def splits_dir(tmp_path):
    """Populate ``tmp_path`` with a small splits tree and return it.

    Layout produced::

        <tmp_path>/train/labels/img_train_000.txt ... img_train_002.txt
        <tmp_path>/val/labels/img_val_000.txt   ... img_val_002.txt

    Each label file holds the single line ``0 0.5 0.5 0.1 0.1``. No ``test``
    directory is created, so tests can observe how a missing split is handled.
    The returned path is ``tmp_path`` itself, not a sub-directory of it.
    """
    label_line = "0 0.5 0.5 0.1 0.1\n"
    files_per_split = 3

    # Only "train" and "val" are materialised; "test" is deliberately absent.
    for split_name in ("train", "val"):
        labels_root = tmp_path / split_name / "labels"
        labels_root.mkdir(parents=True)
        for index in range(files_per_split):
            label_file = labels_root / f"img_{split_name}_{index:03d}.txt"
            label_file.write_text(label_line)

    return tmp_path
|
|
|
|
|
|
class TestGenerateSplitsTxt:
    """Tests for ``generate_splits_txt(splits_root, output_path)``.

    Each test writes the listing into pytest's ``tmp_path`` and inspects the
    non-blank lines of the resulting file. Entries are expected to look like
    ``<split>/labels/<name>.txt``.
    """

    @staticmethod
    def _read_entries(listing):
        """Return the non-blank lines of the file at *listing*."""
        return [entry for entry in listing.read_text().splitlines() if entry]

    def test_creates_output_file(self, splits_dir, tmp_path):
        """The listing file exists after the call."""
        # splits_dir is tmp_path itself, so "output" sits next to train/ and val/.
        out = tmp_path / "output" / "split.txt"
        out.parent.mkdir()
        generate_splits_txt(splits_dir, out)
        assert out.exists()

    def test_entries_have_correct_format(self, splits_dir, tmp_path):
        """Every entry is ``<known split>/labels/<something>.txt``."""
        out = tmp_path / "split.txt"
        generate_splits_txt(splits_dir, out)
        entries = self._read_entries(out)

        # Without this guard an empty listing would skip the loop and pass vacuously.
        assert entries, "listing is empty; nothing to validate"

        for entry in entries:
            parts = entry.split("/")
            # Fail with a readable message rather than an IndexError below.
            assert len(parts) >= 3, f"expected <split>/labels/<file>, got: {entry}"
            assert parts[0] in SPLITS, f"unexpected split name in: {entry}"
            assert parts[1] == "labels"
            assert parts[2].endswith(".txt")

    def test_entry_count_matches_label_files(self, splits_dir, tmp_path):
        """One entry is written per label file present on disk."""
        out = tmp_path / "split.txt"
        generate_splits_txt(splits_dir, out)
        entries = self._read_entries(out)
        # 3 train + 3 val = 6; test was absent
        assert len(entries) == 6

    def test_train_entries_come_before_val(self, splits_dir, tmp_path):
        """All ``train`` entries precede every ``val`` entry."""
        out = tmp_path / "split.txt"
        generate_splits_txt(splits_dir, out)
        entries = self._read_entries(out)

        split_of_entry = [entry.split("/")[0] for entry in entries]
        train_positions = [
            pos for pos, name in enumerate(split_of_entry) if name == "train"
        ]
        val_positions = [
            pos for pos, name in enumerate(split_of_entry) if name == "val"
        ]

        # max()/min() raise ValueError on an empty list; report that case clearly.
        assert train_positions, "no train entries found in listing"
        assert val_positions, "no val entries found in listing"
        assert max(train_positions) < min(val_positions)

    def test_missing_split_dir_is_skipped_gracefully(self, tmp_path):
        """A root with no split sub-directories yields an empty listing."""
        # empty directory — no split sub-dirs at all
        out = tmp_path / "split.txt"
        generate_splits_txt(tmp_path, out)
        assert self._read_entries(out) == []

    def test_entries_are_sorted_within_split(self, splits_dir, tmp_path):
        """Entries of each populated split appear in sorted order."""
        out = tmp_path / "split.txt"
        generate_splits_txt(splits_dir, out)
        entries = self._read_entries(out)

        # Check both splits the fixture creates, not just "train".
        for split_name in ("train", "val"):
            prefix = f"{split_name}/"
            split_entries = [entry for entry in entries if entry.startswith(prefix)]
            assert split_entries, f"no {split_name} entries found in listing"
            assert split_entries == sorted(split_entries), (
                f"{split_name} entries are not sorted"
            )
|