diff --git a/tests/test_utils_tqdm_mod.py b/tests/test_utils_tqdm_mod.py index 48eb208a5..6a8ecf837 100644 --- a/tests/test_utils_tqdm_mod.py +++ b/tests/test_utils_tqdm_mod.py @@ -14,47 +14,82 @@ from slic.utils.tqdm_mod import * def extract_last_line(output): return output.strip().splitlines()[-1] -# Test +def get_bar_visual(line): + try: + return line.split("|")[1] + except IndexError: + return "" + +# Tests def test_complete_progress_bar(): f = io.StringIO() with redirect_stdout(f): - for _ in tqdm_mod(range(3), desc="TestBar", file=f, dynamic_ncols=False): + for _ in tqdm_mod(range(3), desc="TestBar", file=f): sleep(0.001) - last_line = extract_last_line(f.getvalue()) - assert last_line == "TestBar: 100%|########################| 3.00/3.00 [00:00<00:00, 3.0 Hz]" + lines = extract_lines(f.getvalue(), "TestBar") + last = lines[-1] + bar = get_bar_visual(last) + + assert last.startswith("TestBar: 100%") + assert "3.00/3.00" in last + assert "Hz" in last + + # Check that the bar us full + assert len(bar.replace("█", "").replace("#", "").strip()) == 0, f"Bar not full: '{bar}'" + def test_set_progress_multiple_points(): f = io.StringIO() with redirect_stdout(f): - bar = tqdm_mod(total=5, desc="SetBar", file=f, dynamic_ncols=False) + bar = tqdm_mod(total=5, desc="SetBar", file=f) bar.set(1.0) bar.set(2.0) bar.set(3.5) bar.set(5.0) bar.close() - lines = f.getvalue().strip().splitlines() - setbar_lines = [line for line in lines if "SetBar:" in line] + lines = extract_lines(f.getvalue(), "SetBar") - assert setbar_lines[0] == "SetBar: 20%|#####6 | 1.00/5.00 [00:00<00:00, 1.0 Hz]" - assert setbar_lines[1] == "SetBar: 40%|###########2 | 2.00/5.00 [00:00<00:00, 2.0 Hz]" - assert setbar_lines[2] == "SetBar: 70%|###################7 | 3.50/5.00 [00:00<00:00, 3.5 Hz]" - assert setbar_lines[3] == "SetBar: 100%|##############################| 5.00/5.00 [00:00<00:00, 5.0 Hz]" + for i, (expected_progress, expected_value) in enumerate([ + ("20%", "1.00/5.00"), + ("40%", "2.00/5.00"), + ("70%", "3.50/5.00"), + ("100%", "5.00/5.00"), + ]): + assert lines[i].startswith(f"SetBar:{expected_progress}") + assert expected_value in lines[i] + bar = get_bar_visual(lines[i]) + assert len(bar.strip()) > 0 def test_custom_unit(): f = io.StringIO() with redirect_stdout(f): - for _ in tqdm_mod(range(4), desc="StepBar", unit="step", file=f, dynamic_ncols=False): + for _ in tqdm_mod(range(4), desc="StepBar", unit="step", file=f): sleep(0.001) - last_line = extract_last_line(f.getvalue()) - assert last_line == "StepBar: 100%|########################| 4.00/4.00 [00:00<00:00, 4.0step/s]" + lines = extract_lines(f.getvalue(), "StepBar") + last = lines[-1] + bar = get_bar_visual(last) + + assert last.startswith("StepBar: 100%") + assert "4.00/4.00" in last + assert "step/s" in last + assert len(bar.strip()) > 0 def test_clamp_above_total(): f = io.StringIO() with redirect_stdout(f): - bar = tqdm_mod(total=10, desc="ClampBar", file=f, dynamic_ncols=False) - bar.set(12) # Clamp to 10 + bar = tqdm_mod(total=10, desc="ClampBar", file=f) + bar.set(12) bar.close() - last_line = extract_last_line(f.getvalue()) - assert last_line == "ClampBar: 100%|##############################| 10.0/10.0 [00:00<00:00, 10.0 Hz]" \ No newline at end of file + lines = extract_lines(f.getvalue(), "ClampBar") + last = lines[-1] + bar = get_bar_visual(last) + + assert last.startswith("ClampBar: 100%") + assert "10.0/10.0" in last + assert "Hz" in last + + # Check that the bar us full + assert len(bar.replace("█", "").replace("#", "").strip()) == 0, f"Bar not full: '{bar}'" +