# pylint:disable=missing-function-docstring
"""Tests for apocalypse.create_bash_script: template loading and SLURM job-script rendering."""
import tempfile
import uuid
from copy import deepcopy
from pathlib import Path
from unittest.mock import Mock, patch

import jinja2
import pytest

from apocalypse.arghandler import get_apo_input
from apocalypse.create_bash_script import (create_script, get_slurm_params,
                                           get_template)

# Source directory of the apocalypse package, located relative to this test file.
APOCALYPSE_DIR = Path(__file__).parent.joinpath("..", "..", "apocalypse")
TEMPLATE_FILE = APOCALYPSE_DIR.joinpath("job_template.jinja")


def _read_job_script(apo_input):
    """Return the text of the job script that ``create_script`` wrote to ``apo_input.slurm_job``."""
    return Path(apo_input.slurm_job).read_text()


@pytest.fixture
def temp_dir():
    """Yield a fresh temporary directory that is deleted when the test finishes."""
    with tempfile.TemporaryDirectory() as directory:
        yield directory


@pytest.fixture
def mock_apo_input(temp_dir):
    """Build a stand-in for the input object used by ``create_script`` / ``get_slurm_params``.

    By default the rendered job script is written into ``temp_dir``.
    """
    mock = Mock(spec=get_apo_input)
    mock.script_file = "test_script.py"
    mock.endstation = "test_endstation"
    mock.slurm_job = Path(temp_dir).joinpath("ap_job.sh")
    mock.send_done_pth = APOCALYPSE_DIR.joinpath("send_done.py")
    mock.header = {
        "action": "write_finished"
    }
    mock.request = {
        "writer_type": "test_writer",
        "output_file": "/path/to/output.txt",
        "metadata": {
            "general/instrument": "test_endstation"
        }
    }
    mock.log_file = "/pth/to/somewhere"
    # Expected output of get_slurm_params for the individual SLURM attributes set below
    # (checked order-insensitively in test_get_slurm_params).
    mock.slurm_params = [
        "--cpus-per-task=5",
        "--mem=0",
        "--partition='prod'",
        "--exclusive",
        "--error='/pth/to/e'",
        "--output='/pth/to/o'",
        "--input='/pth/to/i'",
        "--job-name='job'"
    ]
    # Extra free-form options; one option per list element.
    # NOTE(review): presumably consumed by the template, not by get_slurm_params — confirm.
    mock.slurm_params_other = ["--nice=100", "--account='foo'"]
    mock.cpus_per_task = 5
    mock.mem = "0"
    mock.partition = "prod"
    mock.exclusive = True
    mock.error_pth = "/pth/to/e"
    mock.input_pth = "/pth/to/i"
    mock.output_pth = "/pth/to/o"
    mock.job_name = "job"
    mock.template_file = TEMPLATE_FILE
    mock.broker_url = "test-broker"
    mock.slurm_job_date = "_%Y_%m_%d"
    mock.writer_type = "test_writer"
    return mock


@pytest.fixture
def template():
    """Load the real job template shipped with the package."""
    return get_template(TEMPLATE_FILE)


def test_get_template_type(template):
    assert isinstance(template, jinja2.environment.Template)


def test_get_template_name(template):
    assert template.name == "job_template.jinja"


def test_get_template_missing_template():
    template_file = Path(__file__).parent.joinpath(f"does_not_exist_{uuid.uuid4()}")
    with pytest.raises(IOError):
        get_template(template_file)


def test_file_created(mock_apo_input):
    create_script(mock_apo_input)
    assert mock_apo_input.slurm_job.is_file()


# TODO: also cover the other #SBATCH options, not only --mem.
@pytest.mark.parametrize("mem", ["12G", "0", "8000", "300M"])
def test_slurm_params(mock_apo_input, mem):
    mock_apo_input.mem = mem
    create_script(mock_apo_input)
    assert f"#SBATCH --mem={mem}" in _read_job_script(mock_apo_input)


@pytest.mark.parametrize("script", ["./test", "./dummy.sh", "./tests/ugly_file.m"])
def test_filename(mock_apo_input, script):
    mock_apo_input.script_file = script
    create_script(mock_apo_input)
    assert f"script_file={script}" in _read_job_script(mock_apo_input)


@pytest.mark.parametrize("broker", ["sf-daq-8", "sf-daq-11", "smth"])
def test_broker_url(mock_apo_input, broker):
    mock_apo_input.broker_url = broker
    create_script(mock_apo_input)
    assert f"-b {broker}" in _read_job_script(mock_apo_input)


def test_create_script_missing_template(mock_apo_input):
    params = deepcopy(mock_apo_input)
    params.template_file = Path(__file__).parent.joinpath(f"does_not_exist_{uuid.uuid4()}")
    with pytest.raises(IOError):
        create_script(params)


def test_create_script_issue_with_main_folder(mock_apo_input):
    params = deepcopy(mock_apo_input)
    # Target a file inside a random, non-existent directory directly under the
    # filesystem root, so the job script cannot be written there.
    params.slurm_job = Path("/", str(uuid.uuid4()), f"does_not_exist_{uuid.uuid4()}")
    with pytest.raises(IOError):
        create_script(params)


def test_create_script_issue_with_writing(mock_apo_input):
    # Every open() call fails, simulating a permission problem on write.
    with pytest.raises(IOError), patch("builtins.open", side_effect=OSError("Permission denied")):
        create_script(mock_apo_input)


def test_get_slurm_params(mock_apo_input):
    result = get_slurm_params(mock_apo_input)
    # list.sort() sorts in place and returns None, so compare sorted copies instead.
    assert sorted(result) == sorted(mock_apo_input.slurm_params)

    # With every SLURM-related attribute unset, no parameters should be produced.
    params = deepcopy(mock_apo_input)
    attrs_to_mod = [
        "input_pth", "output_pth", "error_pth", "exclusive",
        "partition", "mem", "cpus_per_task", "job_name"
    ]
    for attr in attrs_to_mod:
        setattr(params, attr, None)
    assert not get_slurm_params(params)

# TODO: consider testing several output path names (possibly redundant with the
# random-folder tests above).