diff --git a/docs/content/multiprocess/_index.md b/docs/content/multiprocess/_index.md index 42ea6a67..d6fd4a03 100644 --- a/docs/content/multiprocess/_index.md +++ b/docs/content/multiprocess/_index.md @@ -96,3 +96,45 @@ from prometheus_client import Gauge # Example gauge IN_PROGRESS = Gauge("inprogress_requests", "help", multiprocess_mode='livesum') ``` + +**5. Customizing metric values**: + +It's possible to customize the behavior of metric values by providing your own implementation of the `ValueClass`. This is useful if you want to add logging, custom synchronization, or change the data storage mechanism. + +The `MmapedValue` and `MutexValue` classes are available in `prometheus_client.values` for this purpose. These are top-level classes, which makes it easy to inherit from them and override their methods. + +To provide a custom `ValueClass`, set the `PROMETHEUS_VALUE_CLASS` environment variable to the full Python path of your class (e.g., `myapp.custom_values.MyValueClass`). + +The class should inherit from `prometheus_client.values.MutexValue` (for single-process applications) or `prometheus_client.values.MmapedValue` (for multiprocess applications) to reuse the existing logic. + +#### Example: Custom Mmaped Value + +If you're using multiprocess mode and want to override the default increment behavior: + +```python +# myapp/custom_values.py +from prometheus_client.values import MmapedValue + +class MyMmapedValue(MmapedValue): + def inc(self, amount): + print(f"Incrementing metric by {amount}") + # Always call the superclass method to ensure the value is + # correctly stored and shared state is handled. + super().inc(amount) +``` + +Then, set the environment variable: + +```bash +export PROMETHEUS_VALUE_CLASS=myapp.custom_values.MyMmapedValue +``` + +#### Behavior and Requirements: +- The environment variable must be set before any metric is instantiated. Therefore, it should preferably be set before the Python process starts.
+- The path must be a valid Python path to a class (including the class name). +- If the class cannot be imported, an `ImportError` will be raised during initialization. +- By default, `prometheus_client` uses `MmapedValue` if `PROMETHEUS_MULTIPROC_DIR` is set, and `MutexValue` otherwise. + +**6. Advanced Customization with `MultiProcessValue`**: + +For specialized use cases where you need a different process identifier than `os.getpid()`, you can use the `MultiProcessValue(process_identifier)` factory function. This returns a subclass of `MmapedValue` that uses the provided function to identify the process. Note that this cannot be set via the `PROMETHEUS_VALUE_CLASS` environment variable. diff --git a/prometheus_client/values.py b/prometheus_client/values.py index 6ff85e3b..a2e06754 100644 --- a/prometheus_client/values.py +++ b/prometheus_client/values.py @@ -1,3 +1,4 @@ +import importlib import os from threading import Lock import warnings @@ -36,6 +37,82 @@ def get_exemplar(self): return self._exemplar +class MmapedValue: + """A float protected by a mutex backed by a per-process mmaped file.""" + + _multiprocess = True + _files = {} + _values = [] + _pid = {'value': os.getpid()} + _lock = Lock() + _process_identifier = staticmethod(os.getpid) + + def __init__(self, typ, metric_name, name, labelnames, labelvalues, help_text, multiprocess_mode='', **kwargs): + self._params = typ, metric_name, name, labelnames, labelvalues, help_text, multiprocess_mode + # This deprecation warning can go away in a few releases when removing the compatibility + if 'prometheus_multiproc_dir' in os.environ and 'PROMETHEUS_MULTIPROC_DIR' not in os.environ: + os.environ['PROMETHEUS_MULTIPROC_DIR'] = os.environ['prometheus_multiproc_dir'] + warnings.warn("prometheus_multiproc_dir variable has been deprecated in favor of the upper case naming PROMETHEUS_MULTIPROC_DIR", DeprecationWarning) + with self._lock: + self.__check_for_pid_change() + self.__reset() + self._values.append(self) + + 
def __reset(self): + typ, metric_name, name, labelnames, labelvalues, help_text, multiprocess_mode = self._params + if typ == 'gauge': + file_prefix = typ + '_' + multiprocess_mode + else: + file_prefix = typ + if file_prefix not in self._files: + filename = os.path.join( + os.environ.get('PROMETHEUS_MULTIPROC_DIR'), + '{}_{}.db'.format(file_prefix, self._pid['value'])) + + self._files[file_prefix] = MmapedDict(filename) + self._file = self._files[file_prefix] + self._key = mmap_key(metric_name, name, labelnames, labelvalues, help_text) + self._value, self._timestamp = self._file.read_value(self._key) + + def __check_for_pid_change(self): + actual_pid = self._process_identifier() + if self._pid['value'] != actual_pid: + self._pid['value'] = actual_pid + # There has been a fork(), reset all the values. + for f in self._files.values(): + f.close() + self._files.clear() + for value in self._values: + value.__reset() + + def inc(self, amount): + with self._lock: + self.__check_for_pid_change() + self._value += amount + self._timestamp = 0.0 + self._file.write_value(self._key, self._value, self._timestamp) + + def set(self, value, timestamp=None): + with self._lock: + self.__check_for_pid_change() + self._value = value + self._timestamp = timestamp or 0.0 + self._file.write_value(self._key, self._value, self._timestamp) + + def set_exemplar(self, exemplar): + # TODO: Implement exemplars for multiprocess mode. + return + + def get(self): + with self._lock: + self.__check_for_pid_change() + return self._value + + def get_exemplar(self): + # TODO: Implement exemplars for multiprocess mode. + return None + + def MultiProcessValue(process_identifier=os.getpid): """Returns a MmapedValue class based on a process_identifier function. @@ -44,85 +121,14 @@ def MultiProcessValue(process_identifier=os.getpid): Using a different function than the default 'os.getpid' is at your own risk. 
""" - files = {} - values = [] - pid = {'value': process_identifier()} - # Use a single global lock when in multi-processing mode - # as we presume this means there is no threading going on. - # This avoids the need to also have mutexes in __MmapDict. - lock = Lock() - - class MmapedValue: - """A float protected by a mutex backed by a per-process mmaped file.""" - - _multiprocess = True - - def __init__(self, typ, metric_name, name, labelnames, labelvalues, help_text, multiprocess_mode='', **kwargs): - self._params = typ, metric_name, name, labelnames, labelvalues, help_text, multiprocess_mode - # This deprecation warning can go away in a few releases when removing the compatibility - if 'prometheus_multiproc_dir' in os.environ and 'PROMETHEUS_MULTIPROC_DIR' not in os.environ: - os.environ['PROMETHEUS_MULTIPROC_DIR'] = os.environ['prometheus_multiproc_dir'] - warnings.warn("prometheus_multiproc_dir variable has been deprecated in favor of the upper case naming PROMETHEUS_MULTIPROC_DIR", DeprecationWarning) - with lock: - self.__check_for_pid_change() - self.__reset() - values.append(self) - - def __reset(self): - typ, metric_name, name, labelnames, labelvalues, help_text, multiprocess_mode = self._params - if typ == 'gauge': - file_prefix = typ + '_' + multiprocess_mode - else: - file_prefix = typ - if file_prefix not in files: - filename = os.path.join( - os.environ.get('PROMETHEUS_MULTIPROC_DIR'), - '{}_{}.db'.format(file_prefix, pid['value'])) - - files[file_prefix] = MmapedDict(filename) - self._file = files[file_prefix] - self._key = mmap_key(metric_name, name, labelnames, labelvalues, help_text) - self._value, self._timestamp = self._file.read_value(self._key) - - def __check_for_pid_change(self): - actual_pid = process_identifier() - if pid['value'] != actual_pid: - pid['value'] = actual_pid - # There has been a fork(), reset all the values. 
- for f in files.values(): - f.close() - files.clear() - for value in values: - value.__reset() - - def inc(self, amount): - with lock: - self.__check_for_pid_change() - self._value += amount - self._timestamp = 0.0 - self._file.write_value(self._key, self._value, self._timestamp) - - def set(self, value, timestamp=None): - with lock: - self.__check_for_pid_change() - self._value = value - self._timestamp = timestamp or 0.0 - self._file.write_value(self._key, self._value, self._timestamp) - - def set_exemplar(self, exemplar): - # TODO: Implement exemplars for multiprocess mode. - return - - def get(self): - with lock: - self.__check_for_pid_change() - return self._value - - def get_exemplar(self): - # TODO: Implement exemplars for multiprocess mode. - return None - - return MmapedValue + class _MmapedValue(MmapedValue): + _files = {} + _values = [] + _pid = {'value': process_identifier()} + _lock = Lock() + _process_identifier = staticmethod(process_identifier) + + return _MmapedValue def get_value_class(): @@ -130,10 +136,20 @@ def get_value_class(): # This needs to be chosen before the first metric is constructed, # and as that may be in some arbitrary library the user/admin has # no control over we use an environment variable. + value_class_path = os.environ.get('PROMETHEUS_VALUE_CLASS') + if value_class_path: + if '.' not in value_class_path: + raise ImportError(f"PROMETHEUS_VALUE_CLASS must be a full python path (e.g. 
module.ClassName), got '{value_class_path}'") + try: + module_path, class_name = value_class_path.rsplit('.', 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) + except (ImportError, AttributeError) as e: + raise ImportError(f"Could not import PROMETHEUS_VALUE_CLASS '{value_class_path}': {e}") from None + if 'prometheus_multiproc_dir' in os.environ or 'PROMETHEUS_MULTIPROC_DIR' in os.environ: - return MultiProcessValue() - else: - return MutexValue + return MmapedValue + return MutexValue ValueClass = get_value_class() diff --git a/tests/e2e/server.py b/tests/e2e/server.py new file mode 100644 index 00000000..52aef419 --- /dev/null +++ b/tests/e2e/server.py @@ -0,0 +1,90 @@ +import os +import sys + +# Use LockedMmapedValue for this server +os.environ['PROMETHEUS_VALUE_CLASS'] = 'prometheus_client.values.LockedMmapedValue' + +import http.server +import json +from urllib.parse import urlparse, parse_qs +from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, generate_latest, values +from prometheus_client.multiprocess import MultiProcessAggregateCollector + +# Define metrics at module level +C = Counter('c', 'test counter', ['l']) +G_SUM = Gauge('g_sum', 'test gauge sum', ['l'], multiprocess_mode='sum') +G_MAX = Gauge('g_max', 'test gauge max', ['l'], multiprocess_mode='max') +G_MIN = Gauge('g_min', 'test gauge min', ['l'], multiprocess_mode='min') +G_MOSTRECENT = Gauge('g_mostrecent', 'test gauge mostrecent', ['l'], multiprocess_mode='mostrecent') +G_ALL = Gauge('g_all', 'test gauge all', ['l'], multiprocess_mode='all') +G_LIVESUM = Gauge('g_livesum', 'test gauge livesum', ['l'], multiprocess_mode='livesum') +G_LIVEMAX = Gauge('g_livemax', 'test gauge livemax', ['l'], multiprocess_mode='livemax') +G_LIVEMIN = Gauge('g_livemin', 'test gauge livemin', ['l'], multiprocess_mode='livemin') +G_LIVEMOSTRECENT = Gauge('g_livemostrecent', 'test gauge livemostrecent', ['l'], multiprocess_mode='livemostrecent') 
+G_LIVEALL = Gauge('g_liveall', 'test gauge liveall', ['l'], multiprocess_mode='liveall') +H = Histogram('h', 'test histogram', ['l'], buckets=(1.0, 5.0, 10.0)) + +METRICS = { + 'c': C, + 'g_sum': G_SUM, + 'g_max': G_MAX, + 'g_min': G_MIN, + 'g_mostrecent': G_MOSTRECENT, + 'g_all': G_ALL, + 'g_livesum': G_LIVESUM, + 'g_livemax': G_LIVEMAX, + 'g_livemin': G_LIVEMIN, + 'g_livemostrecent': G_LIVEMOSTRECENT, + 'g_liveall': G_LIVEALL, + 'h': H, +} + +class MetricHandler(http.server.BaseHTTPRequestHandler): + def send_ok(self, data=b'OK', content_type='text/plain'): + self.send_response(200) + self.send_header('Content-Type', content_type) + self.end_headers() + self.wfile.write(data) + + def send_error(self, code=404): + self.send_response(code) + self.end_headers() + + def do_GET(self): + parsed_url = urlparse(self.path) + query = parse_qs(parsed_url.query) + path = parsed_url.path + + if path == '/metrics': + registry = CollectorRegistry() + MultiProcessAggregateCollector(registry) + self.send_ok(generate_latest(registry)) + elif path in ('/inc', '/set', '/observe'): + name = query.get('name', [None])[0] + labels_json = query.get('labels', ['{}'])[0] + labels = json.loads(labels_json) + value = float(query.get('value', query.get('amount', [1]))[0]) + + if name not in METRICS: + self.send_error(400) + return + + m = METRICS[name] + metric_with_labels = m.labels(**labels) if labels else m + + if path == '/inc': + metric_with_labels.inc(value) + elif path == '/set': + metric_with_labels.set(value) + elif path == '/observe': + metric_with_labels.observe(value) + + self.send_ok() + else: + self.send_error() + +if __name__ == '__main__': + port = int(sys.argv[1]) + server = http.server.HTTPServer(('127.0.0.1', port), MetricHandler) + print(f'Starting server on port {port}') + server.serve_forever() diff --git a/tests/e2e/test_multi_process.py b/tests/e2e/test_multi_process.py new file mode 100644 index 00000000..4c906894 --- /dev/null +++ b/tests/e2e/test_multi_process.py 
@@ -0,0 +1,280 @@ +import os +import sys +import subprocess +import time +import unittest +import shutil +import urllib.request +import tempfile +import json +from prometheus_client.parser import text_string_to_metric_families + +class TestMultiProcessAggregate(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + os.environ['PROMETHEUS_MULTIPROC_DIR'] = self.tmpdir + self.processes = [] + + def tearDown(self): + for p in self.processes: + p.terminate() + p.wait() + shutil.rmtree(self.tmpdir) + + def start_server(self, port): + # We need to make sure prometheus_client is in PYTHONPATH + env = os.environ.copy() + env['PYTHONPATH'] = os.getcwd() + print(f"DEBUG: Starting server on port {port} with PROMETHEUS_MULTIPROC_DIR={env.get('PROMETHEUS_MULTIPROC_DIR')}") + p = subprocess.Popen([sys.executable, 'tests/e2e/server.py', str(port)], env=env) + self.processes.append(p) + # Wait for server to start + max_retries = 10 + for i in range(max_retries): + try: + urllib.request.urlopen(f'http://127.0.0.1:{port}/metrics', timeout=1) + break + except: + time.sleep(0.5) + else: + self.fail(f"Server on port {port} failed to start") + return p + + def get_metrics(self, port): + content = urllib.request.urlopen(f'http://127.0.0.1:{port}/metrics').read().decode() + print(f"DEBUG: Metrics from {port}:\n{content}") + families = text_string_to_metric_families(content) + metrics = {} + for family in families: + for sample in family.samples: + # Store by (name, labels_tuple) + labels = tuple(sorted(sample.labels.items())) + metrics[(sample.name, labels)] = sample.value + return metrics + + def call_metric(self, port, action, name, labels, value): + import urllib.parse + labels_json = json.dumps(labels) + labels_encoded = urllib.parse.quote(labels_json) + url = f'http://127.0.0.1:{port}/{action}?name={name}&labels={labels_encoded}&value={value}' + return urllib.request.urlopen(url).read() + + def test_aggregation_and_modes(self): + port1 = 12345 + port2 = 12346 + + 
# Start two servers + p1 = self.start_server(port1) + p2 = self.start_server(port2) + + labels = '{"l": "v"}' + labels_dict = {"l": "v"} + labels_tuple = (("l", "v"),) + + import urllib.parse + labels_encoded = urllib.parse.quote(labels) + + # 1. Test Counters (Aggregate by sum) + self.call_metric(port1, 'inc', 'c', labels_dict, 10) + time.sleep(1) + self.call_metric(port2, 'inc', 'c', labels_dict, 20) + time.sleep(1) + + # 2. Test Gauges (Various modes) + # sum + self.call_metric(port1, 'set', 'g_sum', labels_dict, 10) + self.call_metric(port2, 'set', 'g_sum', labels_dict, 20) + # max + self.call_metric(port1, 'set', 'g_max', labels_dict, 10) + self.call_metric(port2, 'set', 'g_max', labels_dict, 20) + # min + self.call_metric(port1, 'set', 'g_min', labels_dict, 10) + self.call_metric(port2, 'set', 'g_min', labels_dict, 20) + # mostrecent + self.call_metric(port1, 'set', 'g_mostrecent', labels_dict, 10) + time.sleep(0.1) # Ensure different timestamp if possible (though mmap might not have high res) + self.call_metric(port2, 'set', 'g_mostrecent', labels_dict, 20) + + # 3. 
Test Histograms + self.call_metric(port1, 'observe', 'h', labels_dict, 2) + self.call_metric(port2, 'observe', 'h', labels_dict, 6) + + # Check metrics while both are alive + m = self.get_metrics(port1) + + # Verify .db files exist for both processes + # p1's files + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, f'counter_{p1.pid}.db'))) + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, f'gauge_sum_{p1.pid}.db'))) + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, f'histogram_{p1.pid}.db'))) + # p2's files + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, f'counter_{p2.pid}.db'))) + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, f'gauge_sum_{p2.pid}.db'))) + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, f'histogram_{p2.pid}.db'))) + + self.assertEqual(m[('c_total', labels_tuple)], 30.0) + self.assertEqual(m[('g_sum', labels_tuple)], 30.0) + self.assertEqual(m[('g_max', labels_tuple)], 20.0) + self.assertEqual(m[('g_min', labels_tuple)], 10.0) + self.assertEqual(m[('g_mostrecent', labels_tuple)], 20.0) + self.assertEqual(m[('h_count', labels_tuple)], 2.0) + self.assertEqual(m[('h_sum', labels_tuple)], 8.0) + self.assertEqual(m[('h_bucket', labels_tuple + (('le', '1.0'),))], 0.0) + self.assertEqual(m[('h_bucket', labels_tuple + (('le', '5.0'),))], 1.0) + self.assertEqual(m[('h_bucket', labels_tuple + (('le', '10.0'),))], 2.0) + self.assertEqual(m[('h_bucket', labels_tuple + (('le', '+Inf'),))], 2.0) + + # Kill port2 server + p2.terminate() + p2.wait() + self.processes.remove(p2) + + # Check metrics from surviving server (should be aggregated) + m = self.get_metrics(port1) + + # Verify p2's .db files are gone after collection + self.assertFalse(os.path.exists(os.path.join(self.tmpdir, f'counter_{p2.pid}.db'))) + self.assertFalse(os.path.exists(os.path.join(self.tmpdir, f'gauge_sum_{p2.pid}.db'))) + self.assertFalse(os.path.exists(os.path.join(self.tmpdir, f'histogram_{p2.pid}.db'))) + # p1's files should still 
exist + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, f'counter_{p1.pid}.db'))) + + self.assertEqual(m[('c_total', labels_tuple)], 30.0) + self.assertEqual(m[('g_sum', labels_tuple)], 30.0) + self.assertEqual(m[('g_max', labels_tuple)], 20.0) + self.assertEqual(m[('g_min', labels_tuple)], 10.0) + self.assertEqual(m[('g_mostrecent', labels_tuple)], 20.0) + self.assertEqual(m[('h_count', labels_tuple)], 2.0) + self.assertEqual(m[('h_sum', labels_tuple)], 8.0) + self.assertEqual(m[('h_bucket', labels_tuple + (('le', '1.0'),))], 0.0) + self.assertEqual(m[('h_bucket', labels_tuple + (('le', '5.0'),))], 1.0) + self.assertEqual(m[('h_bucket', labels_tuple + (('le', '10.0'),))], 2.0) + self.assertEqual(m[('h_bucket', labels_tuple + (('le', '+Inf'),))], 2.0) + + # Ensure aggregate.db exists + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, 'aggregate.db'))) + + # Kill surviving server + p1.terminate() + p1.wait() + self.processes.remove(p1) + + # Start new server p3 + port3 = 12347 + p3 = self.start_server(port3) + + # Check metrics from p3 (should read from aggregate.db) + m = self.get_metrics(port3) + self.assertEqual(m[('c_total', labels_tuple)], 30.0) + self.assertEqual(m[('g_sum', labels_tuple)], 30.0) + self.assertEqual(m[('g_max', labels_tuple)], 20.0) + self.assertEqual(m[('g_min', labels_tuple)], 10.0) + self.assertEqual(m[('g_mostrecent', labels_tuple)], 20.0) + self.assertEqual(m[('h_count', labels_tuple)], 2.0) + + # Add more to p3 + self.call_metric(port3, 'inc', 'c', labels_dict, 5) + m = self.get_metrics(port3) + self.assertEqual(m[('c_total', labels_tuple)], 35.0) + + def test_live_gauges(self): + # Test various live gauge modes + modes = { + 'g_livesum': 30.0, + 'g_livemax': 20.0, + 'g_livemin': 10.0, + 'g_livemostrecent': 20.0, + } + + for name, expected_sum in modes.items(): + port1 = 12348 + port2 = 12349 + p1 = self.start_server(port1) + p2 = self.start_server(port2) + + labels_dict = {"l": "live"} + labels_tuple = (("l", "live"),) + 
+ # Set live gauges + self.call_metric(port1, 'set', name, labels_dict, 10) + if name == 'g_livemostrecent': + time.sleep(0.1) + self.call_metric(port2, 'set', name, labels_dict, 20) + + m = self.get_metrics(port1) + self.assertEqual(m[(name, labels_tuple)], expected_sum, f"Failed for {name} with both alive") + + # Verify .db files exist + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, f'gauge_{name[2:]}_{p1.pid}.db'))) + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, f'gauge_{name[2:]}_{p2.pid}.db'))) + + # Kill p2 + p2.terminate() + p2.wait() + self.processes.remove(p2) + + # Live gauge should only reflect p1 now, as p2 is dead and live gauges are not aggregated into aggregate.db + m = self.get_metrics(port1) + + # Verify p2's live gauge .db file is gone + self.assertFalse(os.path.exists(os.path.join(self.tmpdir, f'gauge_{name[2:]}_{p2.pid}.db'))) + # p1's should still exist + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, f'gauge_{name[2:]}_{p1.pid}.db'))) + + self.assertEqual(m[(name, labels_tuple)], 10.0, f"Failed for {name} after p2 death") + + # Cleanup for next iteration + p1.terminate() + p1.wait() + self.processes.remove(p1) + shutil.rmtree(self.tmpdir) + self.tmpdir = tempfile.mkdtemp() + os.environ['PROMETHEUS_MULTIPROC_DIR'] = self.tmpdir + + def test_live_all_gauge(self): + # Test liveall mode separately as it keeps PID labels + port1 = 12350 + port2 = 12351 + p1 = self.start_server(port1) + p2 = self.start_server(port2) + + pid1 = str(p1.pid) + pid2 = str(p2.pid) + + labels_dict = {"l": "liveall"} + + self.call_metric(port1, 'set', 'g_liveall', labels_dict, 10) + self.call_metric(port2, 'set', 'g_liveall', labels_dict, 20) + + m = self.get_metrics(port1) + # Should have two entries with different pids + + # Verify .db files exist + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, f'gauge_liveall_{p1.pid}.db'))) + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, f'gauge_liveall_{p2.pid}.db'))) + + 
expected_metrics = { + (('l', 'liveall'), ('pid', pid1)): 10.0, + (('l', 'liveall'), ('pid', pid2)): 20.0, + } + for labels, val in expected_metrics.items(): + self.assertEqual(m.get(('g_liveall', labels)), val, f"Missing or incorrect metric for pid {labels}") + + # Kill p2 + p2.terminate() + p2.wait() + self.processes.remove(p2) + + # Now should only have one entry (p1's) + m = self.get_metrics(port1) + + # Verify p2's liveall .db file is gone + self.assertFalse(os.path.exists(os.path.join(self.tmpdir, f'gauge_liveall_{p2.pid}.db'))) + # p1's should still exist + self.assertTrue(os.path.exists(os.path.join(self.tmpdir, f'gauge_liveall_{p1.pid}.db'))) + + self.assertEqual(m.get(('g_liveall', (('l', 'liveall'), ('pid', pid1)))), 10.0) + self.assertNotIn(('g_liveall', (('l', 'liveall'), ('pid', pid2))), m) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_multiprocess.py b/tests/test_multiprocess.py index ee0c7423..4898e83b 100644 --- a/tests/test_multiprocess.py +++ b/tests/test_multiprocess.py @@ -1,4 +1,5 @@ import glob +import importlib import os import shutil import tempfile @@ -646,3 +647,70 @@ def test_file_syncpath(self): def tearDown(self): os.remove(self.tmpfl) + + +class TestCustomValueClass(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + os.environ['PROMETHEUS_MULTIPROC_DIR'] = self.tmpdir + + def tearDown(self): + shutil.rmtree(self.tmpdir) + # Restore default ValueClass + values.ValueClass = values.get_value_class() + + def test_custom_value_class(self): + class MyCustomValue(values.MmapedValue): + def inc(self, amount): + super().inc(amount * 2) # Double the increment + + values.ValueClass = MyCustomValue + + c = Counter('my_counter', 'help') + c.inc(5) + + self.assertEqual(c._value.get(), 10) + self.assertIsInstance(c._value, MyCustomValue) + + +class TestValueClassEnv(unittest.TestCase): + def setUp(self): + self.original_env = os.environ.get('PROMETHEUS_VALUE_CLASS') + if 'PROMETHEUS_VALUE_CLASS' in 
os.environ: + del os.environ['PROMETHEUS_VALUE_CLASS'] + + def tearDown(self): + if self.original_env: + os.environ['PROMETHEUS_VALUE_CLASS'] = self.original_env + elif 'PROMETHEUS_VALUE_CLASS' in os.environ: + del os.environ['PROMETHEUS_VALUE_CLASS'] + # Reset ValueClass to default for other tests + importlib.reload(values) + + def test_default_value_class(self): + importlib.reload(values) + self.assertEqual(values.ValueClass, values.MutexValue) + + def test_multiproc_value_class(self): + os.environ['PROMETHEUS_MULTIPROC_DIR'] = '/tmp' + importlib.reload(values) + self.assertEqual(values.ValueClass, values.MmapedValue) + del os.environ['PROMETHEUS_MULTIPROC_DIR'] + + def test_env_var_value_class(self): + # We need a class to point to. Let's use MutexValue itself but via string + os.environ['PROMETHEUS_VALUE_CLASS'] = 'prometheus_client.values.MutexValue' + importlib.reload(values) + self.assertEqual(values.ValueClass, values.MutexValue) + + def test_invalid_path_fails_loudly(self): + os.environ['PROMETHEUS_VALUE_CLASS'] = 'invalid.path.Class' + with self.assertRaises(ImportError) as cm: + importlib.reload(values) + self.assertIn("Could not import PROMETHEUS_VALUE_CLASS", str(cm.exception)) + + def test_no_dot_fails_loudly(self): + os.environ['PROMETHEUS_VALUE_CLASS'] = 'NoDot' + with self.assertRaises(ImportError) as cm: + importlib.reload(values) + self.assertIn("PROMETHEUS_VALUE_CLASS must be a full python path", str(cm.exception))