Source code for backtrader.test_helpers

#!/usr/bin/env python
"""Test Helpers Module - Utility functions for testing.

This module provides helper functions for testing backtrader,
including registering and retrieving expected test values.

Functions:
    register_test_values: Register expected values for a test.
    get_test_value: Get expected value for a test.
    is_test_mode: Check if running in test context.

Example:
    Registering test values:
    >>> from backtrader.test_helpers import register_test_values
    >>> register_test_values('mytest', values=[100.0], cash=[10000.0])
"""

import os
import sys

from .utils.log_message import get_logger

logger = get_logger(__name__)

# Dictionary to store registered test values
_TEST_VALUES = {}


[docs] def register_test_values(test_name, values=None, cash=None): """Register expected value and cash values for a specific test. Args: test_name: Name of the test. values: Expected portfolio values. cash: Expected cash values. """ _TEST_VALUES[test_name] = {"values": values, "cash": cash}
[docs] def get_test_value(test_file, index=0): """Get expected value for current test if running in test mode""" if not test_file: return None, None # Test case specific values if test_file in _TEST_VALUES and _TEST_VALUES[test_file]["values"]: values = _TEST_VALUES[test_file]["values"] cash = _TEST_VALUES[test_file]["cash"] if index < len(values): return float(values[index]), float(cash[index]) if cash and index < len(cash) else None # Try to import from test module directly try: test_name = os.path.basename(test_file) if test_name == "test_strategy_optimized.py": from tests.original_tests.test_strategy_optimized import CHKCASH, CHKVALUES if index < len(CHKVALUES): return float(CHKVALUES[index]), ( float(CHKCASH[index]) if index < len(CHKCASH) else None ) elif test_name == "test_strategy_unoptimized.py": # The unoptimized test checks specific values in the stop method if test_name == "test_strategy_unoptimized.py": if not _TEST_VALUES.get(test_name): # Register the expected values for stocklike=True case _TEST_VALUES[test_name] = { "values": ["10284.10"], # Portfolio value "cash": ["6164.16"], # Cash value } if index < len(_TEST_VALUES[test_name]["values"]): return float(_TEST_VALUES[test_name]["values"][index]), float( _TEST_VALUES[test_name]["cash"][index] ) except Exception as e: logger.warning("Error accessing test values: %s", e, exc_info=True) return None, None
[docs] def is_test_mode(): """Check if we're running in a test context""" if not hasattr(sys, "argv") or len(sys.argv) == 0: return False test_file = os.path.basename(sys.argv[0]) return test_file.startswith("test_") and test_file.endswith(".py")
[docs] def get_current_test_file(): """Get current test file name if in test mode""" if not is_test_mode(): return None return os.path.basename(sys.argv[0])