"""Unit tests for src/tools/generate_splits.py""" import pytest from pathlib import Path from src.tools.generate_splits import generate_splits_txt, SPLITS @pytest.fixture def splits_dir(tmp_path): """Create a minimal splits directory with label files.""" for split in ["train", "val"]: label_dir = tmp_path / split / "labels" label_dir.mkdir(parents=True) for i in range(3): (label_dir / f"img_{split}_{i:03d}.txt").write_text("0 0.5 0.5 0.1 0.1\n") # leave "test" absent intentionally return tmp_path class TestGenerateSplitsTxt: def test_creates_output_file(self, splits_dir, tmp_path): out = tmp_path / "output" / "split.txt" out.parent.mkdir() generate_splits_txt(splits_dir, out) assert out.exists() def test_entries_have_correct_format(self, splits_dir, tmp_path): out = tmp_path / "split.txt" generate_splits_txt(splits_dir, out) lines = [l for l in out.read_text().splitlines() if l] for line in lines: parts = line.split("/") assert parts[0] in SPLITS, f"unexpected split name in: {line}" assert parts[1] == "labels" assert parts[2].endswith(".txt") def test_entry_count_matches_label_files(self, splits_dir, tmp_path): out = tmp_path / "split.txt" generate_splits_txt(splits_dir, out) lines = [l for l in out.read_text().splitlines() if l] # 3 train + 3 val = 6; test was absent assert len(lines) == 6 def test_train_entries_come_before_val(self, splits_dir, tmp_path): out = tmp_path / "split.txt" generate_splits_txt(splits_dir, out) lines = [l for l in out.read_text().splitlines() if l] splits_seen = [l.split("/")[0] for l in lines] train_indices = [i for i, s in enumerate(splits_seen) if s == "train"] val_indices = [i for i, s in enumerate(splits_seen) if s == "val"] assert max(train_indices) < min(val_indices) def test_missing_split_dir_is_skipped_gracefully(self, tmp_path): # empty directory — no split sub-dirs at all out = tmp_path / "split.txt" generate_splits_txt(tmp_path, out) lines = [l for l in out.read_text().splitlines() if l] assert lines == [] def test_entries_are_sorted_within_split(self, splits_dir, tmp_path): out = tmp_path / "split.txt" generate_splits_txt(splits_dir, out) lines = [l for l in out.read_text().splitlines() if l] train_lines = [l for l in lines if l.startswith("train/")] assert train_lines == sorted(train_lines)