Skip to content

Commit 4feef98

Browse files
committed
add: testing suite related to instrumentation
1 parent 85c4354 commit 4feef98

4 files changed

Lines changed: 415 additions & 0 deletions

File tree

tests/test_dynamic_policy.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import os
2+
import shutil
3+
import sys
4+
import tempfile
5+
import unittest
6+
from unittest.mock import MagicMock, patch
7+
8+
import torch
9+
10+
# Ensure traincheck is in path
11+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
12+
13+
import traincheck.config.config as config
14+
from traincheck.instrumentor.caches import META_VARS
15+
from traincheck.instrumentor.tracer import Instrumentor
16+
17+
18+
class TestDynamicPolicy(unittest.TestCase):
19+
def setUp(self):
20+
META_VARS["step"] = 0
21+
META_VARS["stage"] = "training"
22+
config.INSTRUMENTATION_POLICY = None
23+
self.test_dir = tempfile.mkdtemp()
24+
os.environ["TRAINCHECK_OUTPUT_DIR"] = self.test_dir
25+
26+
def tearDown(self):
27+
if os.path.exists(self.test_dir):
28+
shutil.rmtree(self.test_dir)
29+
30+
@patch("traincheck.config.config.DISABLE_WRAPPER", new=False)
31+
def test_sampling_interval(self):
32+
# Setup policy
33+
from traincheck.config import config
34+
from traincheck.instrumentor.control import start_step
35+
36+
config.INSTRUMENTATION_POLICY = {"interval": 2, "warm_up": 0}
37+
38+
start_step() # Step 1: (1-0)%2 != 0 -> Disabled
39+
self.assertTrue(config.DISABLE_WRAPPER, "Step 1 should be disabled")
40+
41+
start_step() # Step 2: (2-0)%2 == 0 -> Enabled
42+
self.assertFalse(config.DISABLE_WRAPPER, "Step 2 should be enabled")
43+
44+
start_step() # Step 3: Disabled
45+
self.assertTrue(config.DISABLE_WRAPPER, "Step 3 should be disabled")
46+
47+
start_step() # Step 4: Enabled
48+
self.assertFalse(config.DISABLE_WRAPPER, "Step 4 should be enabled")
49+
50+
@patch("traincheck.config.config.DISABLE_WRAPPER", new=False)
51+
def test_warmup(self):
52+
# Setup policy
53+
from traincheck.config import config
54+
from traincheck.instrumentor.control import start_step
55+
56+
config.INSTRUMENTATION_POLICY = {"interval": 10, "warm_up": 2}
57+
58+
start_step() # Step 1. Warmup.
59+
self.assertFalse(config.DISABLE_WRAPPER, "Step 1 (warmup) should be enabled")
60+
61+
start_step() # Step 2. Warmup.
62+
self.assertFalse(config.DISABLE_WRAPPER, "Step 2 (warmup) should be enabled")
63+
64+
start_step() # Step 3. (3-2)%10 != 0 -> Disabled.
65+
self.assertTrue(config.DISABLE_WRAPPER, "Step 3 should be disabled")
66+
67+
# Fast forward to step 12
68+
for _ in range(9):
69+
start_step()
70+
71+
# Step 12. (12-2)%10 == 0 -> Enabled.
72+
self.assertFalse(config.DISABLE_WRAPPER, "Step 12 should be enabled")
73+
74+
def test_stage_change_resets_wrapper(self):
75+
# Simulate being in a "skip" state
76+
config.DISABLE_WRAPPER = True
77+
META_VARS["stage"] = "training"
78+
79+
from traincheck.developer.annotations import annotate_stage
80+
81+
# Change stage to evaluation
82+
annotate_stage("evaluation")
83+
84+
# Should be enabled now
85+
self.assertFalse(
86+
config.DISABLE_WRAPPER, "DISABLE_WRAPPER should be False after stage change"
87+
)
88+
self.assertEqual(META_VARS["stage"], "evaluation")
89+
90+
# Change back to training
91+
config.DISABLE_WRAPPER = True # Simulate it was somehow disabled again
92+
annotate_stage("training")
93+
self.assertFalse(
94+
config.DISABLE_WRAPPER,
95+
"DISABLE_WRAPPER should be False after entering training",
96+
)
97+
98+
99+
if __name__ == "__main__":
100+
unittest.main()

tests/test_loop_injection.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import ast
2+
import unittest
3+
4+
from traincheck.instrumentor.source_file import InsertTracerVisitor
5+
6+
7+
class TestLoopInjection(unittest.TestCase):
8+
def test_inject_training_loop(self):
9+
source = """
10+
import torch
11+
def train():
12+
for i in range(10):
13+
data = get_data()
14+
optimizer.step()
15+
"""
16+
17+
visitor = InsertTracerVisitor(
18+
modules_to_instr=["torch"],
19+
scan_proxy_in_args=False,
20+
use_full_instr=False,
21+
funcs_to_instr=None,
22+
API_dump_stack_trace=False,
23+
sampling_interval=1,
24+
warm_up_steps=0,
25+
)
26+
27+
tree = ast.parse(source)
28+
visitor.visit(tree)
29+
30+
new_source = ast.unparse(tree)
31+
new_source = ast.unparse(tree)
32+
33+
self.assertIn("import start_step", new_source)
34+
self.assertIn("start_step()", new_source)
35+
self.assertIn("from traincheck.instrumentor.control", new_source)
36+
37+
self.assertIn("start_step", new_source)
38+
39+
def test_inject_eval_loop(self):
40+
source = """
41+
import torch
42+
def test():
43+
for i in range(10):
44+
print(i)
45+
"""
46+
47+
visitor = InsertTracerVisitor(
48+
modules_to_instr=["torch"],
49+
scan_proxy_in_args=False,
50+
use_full_instr=False,
51+
funcs_to_instr=None,
52+
API_dump_stack_trace=False,
53+
sampling_interval=1,
54+
warm_up_steps=0,
55+
)
56+
57+
tree = ast.parse(source)
58+
visitor.visit(tree)
59+
60+
new_source = ast.unparse(tree)
61+
new_source = ast.unparse(tree)
62+
63+
# Should detect "test" function name and inject start_eval_step
64+
self.assertIn("import start_eval_step", new_source)
65+
self.assertIn("start_eval_step()", new_source)
66+
67+
def test_no_inject_irrelevant_loop(self):
68+
source = """
69+
def check_thing():
70+
for i in range(10):
71+
print(i)
72+
"""
73+
74+
visitor = InsertTracerVisitor(
75+
modules_to_instr=["torch"],
76+
scan_proxy_in_args=False,
77+
use_full_instr=False,
78+
funcs_to_instr=None,
79+
API_dump_stack_trace=False,
80+
sampling_interval=1,
81+
warm_up_steps=0,
82+
)
83+
84+
tree = ast.parse(source)
85+
visitor.visit(tree)
86+
87+
new_source = ast.unparse(tree)
88+
89+
self.assertNotIn("start_step", new_source)
90+
self.assertNotIn("start_eval_step", new_source)
91+
92+
93+
if __name__ == "__main__":
94+
unittest.main()

tests/test_policy_injection.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import os
2+
import shutil
3+
import sys
4+
import tempfile
5+
import unittest
6+
from unittest.mock import MagicMock, patch
7+
8+
# Ensure traincheck is in path (standard pattern for this repo based on other tests)
9+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
10+
11+
import traincheck.collect_trace as collect_trace
12+
from traincheck.config import config
13+
14+
15+
class TestPolicyInjection(unittest.TestCase):
16+
def setUp(self):
17+
self.test_dir = tempfile.mkdtemp()
18+
self.dummy_script = os.path.join(self.test_dir, "dummy_script.py")
19+
with open(self.dummy_script, "w") as f:
20+
f.write("import torch\n")
21+
22+
def tearDown(self):
23+
shutil.rmtree(self.test_dir)
24+
25+
@patch("traincheck.runner.ProgramRunner")
26+
@patch(
27+
"traincheck.instrumentor.instrument_file",
28+
side_effect=collect_trace.instrumentor.instrument_file,
29+
)
30+
def test_policy_injection(self, mock_instrument_file, MockProgramRunner):
31+
# Setup mock runner
32+
mock_runner_instance = MockProgramRunner.return_value
33+
mock_runner_instance.run.return_value = ("output", 0)
34+
35+
# Simulate command line arguments
36+
test_args = [
37+
"collect_trace.py",
38+
"-p",
39+
self.dummy_script,
40+
"--sampling-interval",
41+
"3",
42+
"--warm-up-steps",
43+
"2",
44+
"--only-instr",
45+
]
46+
47+
with patch.object(sys, "argv", test_args):
48+
collect_trace.main()
49+
50+
# Check if instrument_file was called with the correct args
51+
call_args = mock_instrument_file.call_args
52+
self.assertIsNotNone(call_args, "instrument_file was not called")
53+
_, kwargs = call_args
54+
self.assertEqual(kwargs.get("sampling_interval"), 3)
55+
self.assertEqual(kwargs.get("warm_up_steps"), 2)
56+
57+
# Check if the policy ends up in instrumented source code
58+
runner_call_args = MockProgramRunner.call_args
59+
self.assertIsNotNone(runner_call_args, "ProgramRunner was not initialized")
60+
source_code = runner_call_args[0][0] # first arg is source_code
61+
62+
self.assertIn("sampling_interval=3", source_code)
63+
self.assertIn("warm_up_steps=2", source_code)
64+
65+
@patch("traincheck.runner.ProgramRunner")
66+
@patch(
67+
"traincheck.instrumentor.instrument_file",
68+
side_effect=collect_trace.instrumentor.instrument_file,
69+
)
70+
@patch("traincheck.collect_trace.read_inv_file")
71+
def test_defaults_with_invariant(
72+
self, mock_read_inv, mock_instrument_file, MockProgramRunner
73+
):
74+
# Setup mocks
75+
mock_runner_instance = MockProgramRunner.return_value
76+
mock_runner_instance.run.return_value = ("output", 0)
77+
mock_read_inv.return_value = [] # Return empty list of invariants
78+
79+
test_args = [
80+
"collect_trace.py",
81+
"-p",
82+
self.dummy_script,
83+
"-i",
84+
"dummy_inv.json", # Enable invariants
85+
"--only-instr",
86+
]
87+
88+
with patch.object(sys, "argv", test_args):
89+
collect_trace.main()
90+
91+
call_args = mock_instrument_file.call_args
92+
_, kwargs = call_args
93+
94+
# Should default to config values
95+
expected_interval = config.DEFAULT_CHECKING_POLICY["interval"]
96+
expected_warmup = config.DEFAULT_CHECKING_POLICY["warm_up"]
97+
98+
self.assertEqual(kwargs.get("sampling_interval"), expected_interval)
99+
self.assertEqual(kwargs.get("warm_up_steps"), expected_warmup)
100+
101+
@patch("traincheck.runner.ProgramRunner")
102+
@patch(
103+
"traincheck.instrumentor.instrument_file",
104+
side_effect=collect_trace.instrumentor.instrument_file,
105+
)
106+
def test_defaults_without_invariant(self, mock_instrument_file, MockProgramRunner):
107+
# Setup mocks
108+
mock_runner_instance = MockProgramRunner.return_value
109+
mock_runner_instance.run.return_value = ("output", 0)
110+
111+
test_args = ["collect_trace.py", "-p", self.dummy_script, "--only-instr"]
112+
113+
with patch.object(sys, "argv", test_args):
114+
collect_trace.main()
115+
116+
call_args = mock_instrument_file.call_args
117+
_, kwargs = call_args
118+
119+
# Should default to config.INSTRUMENTATION_POLICY values
120+
expected_interval = config.INSTRUMENTATION_POLICY["interval"]
121+
expected_warmup = config.INSTRUMENTATION_POLICY["warm_up"]
122+
123+
self.assertEqual(kwargs.get("sampling_interval"), expected_interval)
124+
self.assertEqual(kwargs.get("warm_up_steps"), expected_warmup)
125+
126+
127+
if __name__ == "__main__":
128+
unittest.main()

0 commit comments

Comments
 (0)