背景¶
backtrader 已经比较完善了,我想要借鉴量化投资框架中其他项目的优势,继续改进优化 backtrader。
任务¶
阅读研究分析 backtrader 这个项目的源代码,了解这个项目。
阅读研究分析/Users/yunjinqi/Documents/量化交易框架/DevilYuan
借鉴这个新项目的优点和功能,给 backtrader 优化改进提供新的建议
写需规文档和设计文档放到这个文档的最下面,方便后续借鉴
DevilYuan 项目简介¶
DevilYuan 是一个 A 股量化交易系统,具有以下核心特点:
完整系统: 从数据到交易的完整解决方案
GUI 界面: 提供 PyQt 图形界面
策略回放: 支持策略的历史回放
实盘对接: 支持券商实盘交易
数据管理: 完善的数据管理模块
监控系统: 实时监控和报警
重点借鉴方向¶
GUI 设计: PyQt 图形界面设计
事件引擎: 事件驱动引擎设计
策略回放: 策略回放功能实现
数据管理: 数据获取和存储管理
交易引擎: 交易执行引擎
监控报警: 监控和报警系统
框架对比分析¶
架构设计对比¶
| 维度 | backtrader | DevilYuan |
|——|———–|———–|
| 核心定位| 回测框架 | 完整交易系统 |
|用户界面| 无/命令行 | PyQt 图形界面 |
|事件驱动| 隐式回调 | 显式事件引擎 |
|并行处理| 多进程优化 | 多进程/多线程混合 |
|数据存储| 内存/文件 | MongoDB |
|实盘交易| 有限支持 | 多券商支持 |
|策略回放| 基本支持 | 周期分割并行 |
|监控报警| 无 | 微信/QQ 通知 |
|应用场景| 策略研发 | A 股实盘交易 |
backtrader 的优势¶
1.通用性强: 不依赖特定数据源,支持全球市场
API 简洁: 易于学习和使用的 Pythonic API
指标丰富: 内置 60+技术指标
社区活跃: 大量第三方扩展和文档
性能优化: LineBuffer 高效内存管理
DevilYuan 的优势¶
完整生态: 从数据获取到实盘交易的完整闭环
GUI 友好: 可视化操作降低使用门槛
A 股适配: 专门针对 A 股市场特性优化
事件驱动: 清晰的事件引擎架构
实时监控: 微信通知和策略监控
多账户: 支持多个模拟/实盘账户
需求规格文档¶
需求 1: PyQt 图形界面¶
需求描述*:
为 backtrader 添加可选的 PyQt 图形界面,提供可视化的策略配置、回测执行和结果分析功能。
功能需求*:
主窗口设计: 提供导航菜单,集成各功能模块
策略配置界面: 可视化配置策略参数、数据源、回测范围
回测执行界面: 显示回测进度、日志输出
结果分析界面: 图表展示、性能指标表格
实时监控界面: 策略运行状态实时更新
暗色主题: 支持暗色主题,长时间使用舒适
非功能需求*:
可选组件: GUI 不影响命令行使用
响应速度: 界面操作响应时间<100ms
内存占用: GUI 内存占用<200MB
需求 2: 增强事件引擎¶
需求描述*:
实现一个独立的事件引擎,支持事件注册、分发、定时器和多线程处理。
功能需求*:
事件注册: 动态注册事件处理器
事件分发: 支持同步和异步事件分发
定时器: 内置定时器功能,支持周期性任务
多线程: 支持多个事件处理线程并行处理
事件优先级: 支持事件优先级排序
非功能需求*:
线程安全: 事件引擎必须线程安全
性能要求: 事件处理延迟<10ms
可扩展性: 支持 100+事件类型
需求 3: 策略回放功能¶
需求描述*:
实现策略回放功能,支持按时间周期分割并行回测,提高回测效率。
功能需求*:
周期分割: 将回测时间段分割成多个周期
并行处理: 多进程并行处理不同周期
状态传递: 周期间正确传递持仓、资金状态
参数组合: 支持多参数组合并行回测
结果汇总: 自动汇总各周期的回测结果
非功能需求*:
性能提升: 并行回测速度提升 3 倍以上
结果一致性: 并行结果与串行结果完全一致
内存控制: 每个进程内存占用<500MB
需求 4: 数据管理模块¶
需求描述*:
建立统一的数据管理模块,支持多数据源和数据缓存。
功能需求*:
多数据源: 支持 CSV、Pandas、数据库等多种数据源
数据缓存: 内存缓存常用数据
数据验证: 数据完整性和一致性检查
自动更新: 支持数据自动下载和更新
数据转换: 自动处理复权、对齐等
非功能需求*:
向后兼容: 现有数据加载方式继续支持
性能要求: 数据加载速度提升 50%
需求 5: 实盘交易引擎¶
需求描述*:
增强实盘交易功能,支持多账户管理和实时交易监控。
功能需求*:
多账户管理: 统一管理多个交易账户
实时同步: 持仓、委托、成交实时更新
交易接口: 统一的买入、卖出、撤单接口
模拟交易: 支持多个模拟账号
交易记录: 交易数据持久化存储
非功能需求*:
接口兼容: 支持主流券商接口
稳定性: 交易过程零故障
延迟控制: 订单提交延迟<100ms
需求 6: 监控报警系统¶
需求描述*:
实现策略监控和报警功能,支持多种通知方式。
功能需求*:
策略监控: 实时监控策略运行状态
信号推送: 买卖信号实时推送
异常报警: 策略异常及时报警
多种通知: 支持邮件、微信、钉钉等通知方式
交互查询: 支持通过消息查询策略状态
非功能需求*:
实时性: 报警延迟<5 秒
可靠性: 报警送达率>99%
设计文档¶
1. PyQt 图形界面设计¶
1.1 整体架构¶
# backtrader/gui/__init__.py
"""
Backtrader GUI Module
提供 PyQt5 实现的图形界面,包含以下模块:
- BtMainWindow: 主窗口
- BtBasicMainWindow: 基础窗口类
- 各功能子窗口
"""
from .main_window import BtMainWindow
from .basic_window import BtBasicMainWindow
__all__ = ['BtMainWindow', 'BtBasicMainWindow']
```bash
#### 1.2 主窗口设计
```python
# backtrader/gui/main_window.py
from PyQt5.QtWidgets import (QMainWindow, QWidget, QVBoxLayout,
QPushButton, QGridLayout, QStatusBar)
from PyQt5.QtCore import Qt
import qdarkstyle
class BtMainWindow(QMainWindow):
"""Backtrader 主窗口
提供导航界面,包含主要功能入口按钮
"""
def __init__(self):
super().__init__()
self._initUi()
def _initUi(self):
self.setWindowTitle('Backtrader 量化交易平台')
self.resize(1000, 700)
# 应用暗色主题
self.setStyleSheet(qdarkstyle.load_stylesheet_pyqt5())
# 中央 Widget
centerWidget = QWidget()
self.setCentralWidget(centerWidget)
# 布局
layout = QGridLayout(centerWidget)
# 功能按钮
buttons = [
('数据管理', self._openDataMgr),
('策略回测', self._openBackTest),
('实时监控', self._openMonitor),
('系统设置', self._openSettings),
]
for i, (text, callback) in enumerate(buttons):
btn = QPushButton(text)
btn.setMinimumHeight(100)
btn.clicked.connect(callback)
layout.addWidget(btn, i // 2, i % 2)
# 状态栏
self._statusBar = QStatusBar()
self.setStatusBar(self._statusBar)
def _openDataMgr(self):
"""打开数据管理窗口"""
from .data_window import BtDataWindow
window = BtDataWindow(self)
window.show()
def _openBackTest(self):
"""打开回测窗口"""
from .backtest_window import BtBackTestWindow
window = BtBackTestWindow(self)
window.show()
def _openMonitor(self):
"""打开监控窗口"""
from .monitor_window import BtMonitorWindow
window = BtMonitorWindow(self)
window.show()
def _openSettings(self):
"""打开设置窗口"""
from .settings_window import BtSettingsWindow
window = BtSettingsWindow(self)
window.show()
```bash
#### 1.3 基础窗口类
```python
# backtrader/gui/basic_window.py
from PyQt5.QtWidgets import QMainWindow, QStatusBar, QTextEdit
from PyQt5.QtCore import Qt, QMutex, QMutexLocker
from ..events import BtEventEngine, BtEventType
class BtBasicMainWindow(QMainWindow):
"""基础窗口类
提供所有子窗口的通用功能:
- 事件引擎集成
- 互斥操作管理
- 状态栏更新
- 日志输出
"""
def __init__(self, parent=None):
super().__init__(parent)
self._eventEngine = None
self._mutexActions = {} # 互斥操作字典
self._mutex = QMutex()
self._runningAction = None
self._initUi()
def _initUi(self):
"""初始化 UI,子类重写"""
pass
def setEventEngine(self, eventEngine):
"""设置事件引擎"""
self._eventEngine = eventEngine
self._registerEvents()
def _registerEvents(self):
"""注册事件处理器,子类重写"""
pass
def _addMutexAction(self, actionName, actionFunc):
"""添加互斥操作
互斥操作指同一时间只能运行一个操作
"""
self._mutexActions[actionName] = actionFunc
def _startMutexAction(self, actionName):
"""启动互斥操作"""
with QMutexLocker(self._mutex):
if self._runningAction is not None:
return False
self._runningAction = actionName
actionFunc = self._mutexActions.get(actionName)
if actionFunc:
actionFunc()
return True
def _endMutexAction(self):
"""结束互斥操作"""
with QMutexLocker(self._mutex):
self._runningAction = None
def _info(self, msg, level='info'):
"""输出信息到状态栏"""
color = {
'info': 'black',
'success': 'green',
'warning': 'orange',
'error': 'red',
}.get(level, 'black')
self.statusBar().showMessage(f'<font color="{color}">{msg}</font>')
def _log(self, msg, level='info'):
"""输出日志到文本框"""
if hasattr(self, '_logTextEdit'):
color = {
'info': 'black',
'success': 'green',
'warning': 'orange',
'error': 'red',
}.get(level, 'black')
self._logTextEdit.append(
f'<font color="{color}">{msg}</font>'
)
```bash
#### 1.4 回测窗口
```python
# backtrader/gui/backtest_window.py
from PyQt5.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout,
QPushButton, QLineEdit, QLabel, QComboBox,
QTextEdit, QProgressBar, QTableWidget,
QTableWidgetItem, QSplitter)
from PyQt5.QtCore import Qt, QThread
from .basic_window import BtBasicMainWindow
class BtBackTestWindow(BtBasicMainWindow):
"""策略回测窗口"""
def _initUi(self):
self.setWindowTitle('策略回测')
self.resize(1200, 800)
centerWidget = QWidget()
self.setCentralWidget(centerWidget)
layout = QVBoxLayout(centerWidget)
# 参数配置区
paramLayout = self._createParamLayout()
layout.addLayout(paramLayout)
# 分割器: 日志和结果
splitter = QSplitter(Qt.Vertical)
# 日志区
self._logTextEdit = QTextEdit()
self._logTextEdit.setReadOnly(True)
splitter.addWidget(self._logTextEdit)
# 结果表格
self._resultTable = QTableWidget()
splitter.addWidget(self._resultTable)
splitter.setStretchFactor(0, 1)
splitter.setStretchFactor(1, 2)
layout.addWidget(splitter)
# 进度条
self._progressBar = QProgressBar()
layout.addWidget(self._progressBar)
def _createParamLayout(self):
"""创建参数配置布局"""
layout = QHBoxLayout()
# 策略选择
layout.addWidget(QLabel('策略:'))
self._strategyCombo = QComboBox()
self._strategyCombo.addItems(['SmaCross', 'BollingerBands', 'RSI'])
layout.addWidget(self._strategyCombo)
# 开始日期
layout.addWidget(QLabel('开始日期:'))
self._startDateEdit = QLineEdit('2020-01-01')
layout.addWidget(self._startDateEdit)
# 结束日期
layout.addWidget(QLabel('结束日期:'))
self._endDateEdit = QLineEdit('2023-12-31')
layout.addWidget(self._endDateEdit)
# 初始资金
layout.addWidget(QLabel('初始资金:'))
self._cashEdit = QLineEdit('1000000')
layout.addWidget(self._cashEdit)
# 运行按钮
self._runBtn = QPushButton('运行回测')
self._runBtn.clicked.connect(self._runBackTest)
layout.addWidget(self._runBtn)
layout.addStretch()
return layout
def _runBackTest(self):
"""运行回测"""
if not self._startMutexAction('backtest'):
self._info('回测正在运行中...', 'warning')
return
# 创建回测线程
self._backtestThread = BtBackTestThread(
strategy=self._strategyCombo.currentText(),
start_date=self._startDateEdit.text(),
end_date=self._endDateEdit.text(),
cash=float(self._cashEdit.text()),
)
self._backtestThread.finished.connect(self._onBackTestFinished)
self._backtestThread.progress.connect(self._onProgress)
self._backtestThread.log.connect(self._log)
self._backtestThread.start()
def _onBackTestFinished(self, results):
"""回测完成"""
self._endMutexAction()
self._info('回测完成', 'success')
# 显示结果
self._displayResults(results)
def _onProgress(self, value):
"""更新进度"""
self._progressBar.setValue(value)
def _displayResults(self, results):
"""显示回测结果"""
self._resultTable.setRowCount(len(results))
self._resultTable.setColumnCount(2)
self._resultTable.setHorizontalHeaderLabels(['指标', '值'])
for i, (name, value) in enumerate(results.items()):
self._resultTable.setItem(i, 0, QTableWidgetItem(name))
self._resultTable.setItem(i, 1, QTableWidgetItem(str(value)))
class BtBackTestThread(QThread):
"""回测线程"""
finished = pyqtSignal(dict)
progress = pyqtSignal(int)
log = pyqtSignal(str, str)
def __init__(self, strategy, start_date, end_date, cash):
super().__init__()
self.strategy = strategy
self.start_date = start_date
self.end_date = end_date
self.cash = cash
def run(self):
"""执行回测"""
import backtrader as bt
self.log.emit(f'开始回测: {self.strategy}', 'info')
# 创建 Cerebro
cerebro = bt.Cerebro()
cerebro.broker.setcash(self.cash)
# 加载策略
# ... 加载数据和策略 ...
# 运行回测
self.progress.emit(50)
results = cerebro.run()
self.progress.emit(100)
# 提取结果
result_dict = {
'总收益率': f'{results[0].analyzers.pnl.get_analysis()}%',
'夏普比率': results[0].analyzers.sharpe.get_analysis(),
'最大回撤': results[0].analyzers.drawdown.get_analysis(),
}
self.finished.emit(result_dict)
```bash
### 2. 增强事件引擎设计
#### 2.1 核心事件引擎
```python
# backtrader/events/engine.py
import threading
import queue
from collections import defaultdict
from typing import Callable, Dict, List, Any, Optional
from .event import BtEvent, BtEventType
class BtEventEngine(threading.Thread):
"""Backtrader 事件引擎
支持事件注册、分发、定时器和多线程处理
"""
def __init__(self, hand_nbr: int = 2, timer: bool = True):
super().__init__()
self._active = False
self._hand_nbr = hand_nbr
self._timer = timer
# 事件队列
self._queue = queue.Queue()
# 事件处理器
self._hands: List[BtEventHand] = []
self._hand_queues: List[queue.Queue] = []
# 事件映射: {事件类型: [处理器列表]}
self._event_map: Dict[BtEventType, List[Callable]] = defaultdict(list)
# 定时器
self._timer_hand: Optional[BtTimerHand] = None
def start(self):
"""启动事件引擎"""
if not self._active:
self._active = True
# 创建事件处理器
for i in range(self._hand_nbr):
hand_queue = queue.Queue()
hand = BtEventHand(self, hand_queue, i)
hand.start()
self._hands.append(hand)
self._hand_queues.append(hand_queue)
# 创建定时器
if self._timer:
self._timer_hand = BtTimerHand(self)
self._timer_hand.start()
# 启动引擎线程
super().start()
def stop(self):
"""停止事件引擎"""
if self._active:
self._active = False
# 停止事件处理器
for hand in self._hands:
hand.stop()
# 停止定时器
if self._timer_hand:
self._timer_hand.stop()
# 放入停止事件
self.put(BtEvent(BtEventType.stop))
def run(self):
"""事件引擎主循环"""
while self._active:
try:
# 从队列获取事件
event = self._queue.get(timeout=0.1)
self._processEvent(event)
except queue.Empty:
continue
def put(self, event: BtEvent):
"""放入事件到队列"""
self._queue.put(event)
def register(self, event_type: BtEventType, handler: Callable,
hand: Optional[int] = None):
"""注册事件处理器
Args:
event_type: 事件类型
handler: 处理函数
hand: 指定事件处理器索引,None 表示自动分配
"""
# 创建注册事件
reg_event = BtEvent(BtEventType.register)
reg_event.data = {
'type': event_type,
'handler': handler,
'hand': hand,
}
self.put(reg_event)
def unregister(self, event_type: BtEventType, handler: Callable):
"""取消注册事件处理器"""
if event_type in self._event_map:
if handler in self._event_map[event_type]:
self._event_map[event_type].remove(handler)
def registerTimer(self, handler: Callable, interval: int = 1,
hand: Optional[int] = None):
"""注册定时器
Args:
handler: 处理函数
interval: 触发间隔(秒)
hand: 指定事件处理器索引
"""
reg_event = BtEvent(BtEventType.registerTimer)
reg_event.data = {
'handler': handler,
'interval': interval,
'hand': hand,
}
self.put(reg_event)
def _processEvent(self, event: BtEvent):
"""处理事件"""
if event.type == BtEventType.stop:
return
elif event.type == BtEventType.register:
# 注册事件处理器
data = event.data
event_type = data['type']
handler = data['handler']
hand_idx = data.get('hand')
# 添加到事件映射
if handler not in self._event_map[event_type]:
self._event_map[event_type].append(handler)
# 如果指定了处理器,添加到对应处理器
if hand_idx is not None:
self._event_map[f'_hand_{hand_idx}_{event_type}'] = [handler]
elif event.type == BtEventType.registerTimer:
# 注册定时器
if self._timer_hand:
self._timer_hand.register(event.data)
else:
# 普通事件,分发到对应处理器
handlers = self._event_map.get(event.type, [])
# 负载均衡分配到不同处理器
for i, handler in enumerate(handlers):
hand_idx = i % len(self._hand_queues)
# 复制事件,避免修改原始事件
event_copy = BtEvent(event.type, event.data)
self._hand_queues[hand_idx].put((handler, event_copy))
class BtEventHand(threading.Thread):
"""事件处理器线程"""
def __init__(self, engine: BtEventEngine, queue: queue.Queue, idx: int):
super().__init__()
self._engine = engine
self._queue = queue
self._idx = idx
self._active = False
# 事件映射
self._handlers: Dict[BtEventType, List[Callable]] = defaultdict(list)
def start(self):
"""启动事件处理器"""
if not self._active:
self._active = True
super().start()
def stop(self):
"""停止事件处理器"""
self._active = False
def run(self):
"""事件处理器主循环"""
while self._active:
try:
# 获取事件
handler, event = self._queue.get(timeout=0.1)
# 执行处理器
try:
handler(event)
except Exception as e:
# 错误处理
print(f"Event handler error: {e}")
except queue.Empty:
continue
class BtTimerHand(threading.Thread):
"""定时器线程"""
def __init__(self, engine: BtEventEngine):
super().__init__()
self._engine = engine
self._active = False
self._timers: List[Dict] = []
def start(self):
"""启动定时器"""
if not self._active:
self._active = True
super().start()
def stop(self):
"""停止定时器"""
self._active = False
def register(self, data: Dict):
"""注册定时器"""
self._timers.append({
'handler': data['handler'],
'interval': data.get('interval', 1),
'next_time': 0,
})
def run(self):
"""定时器主循环"""
import time
while self._active:
current_time = time.time()
for timer in self._timers:
if current_time >= timer['next_time']:
# 触发定时器
try:
timer['handler'](current_time)
except Exception as e:
print(f"Timer handler error: {e}")
# 更新下次触发时间
timer['next_time'] = current_time + timer['interval']
time.sleep(0.1) # 100ms 检查间隔
```bash
#### 2.2 事件定义
```python
# backtrader/events/event.py
from enum import Enum
from typing import Any, Dict
from dataclasses import dataclass, field
class BtEventType(Enum):
"""事件类型枚举"""
# 系统事件
stop = "stop"
register = "register"
registerTimer = "registerTimer"
# 数据事件
dataLoading = "dataLoading"
dataLoaded = "dataLoaded"
dataError = "dataError"
# 回测事件
backtestStart = "backtestStart"
backtestProgress = "backtestProgress"
backtestFinish = "backtestFinish"
backtestError = "backtestError"
# 策略事件
strategyStart = "strategyStart"
strategyStop = "strategyStop"
strategySignal = "strategySignal"
# 订单事件
orderSubmitted = "orderSubmitted"
orderAccepted = "orderAccepted"
orderRejected = "orderRejected"
orderFilled = "orderFilled"
orderCanceled = "orderCanceled"
# 交易事件
tradeOpened = "tradeOpened"
tradeClosed = "tradeClosed"
# 监控事件
monitorAlert = "monitorAlert"
monitorSignal = "monitorSignal"
@dataclass
class BtEvent:
"""事件对象"""
type: BtEventType
data: Dict[str, Any] = field(default_factory=dict)
def __getitem__(self, key: str) -> Any:
"""获取事件数据"""
return self.data.get(key)
def __setitem__(self, key: str, value: Any):
"""设置事件数据"""
self.data[key] = value
```bash
### 3. 策略回放功能设计
#### 3.1 周期分割并行回测
```python
# backtrader/parallel.py
import multiprocessing as mp
from typing import List, Dict, Any, Tuple, Optional
from datetime import datetime
import numpy as np
import pandas as pd
class BtParallelBackTest:
"""并行回测引擎
支持周期分割和参数组合并行回测
"""
def __init__(self, cerebro, period_nbr: int = 4, param_group_nbr: int = None):
"""初始化并行回测引擎
Args:
cerebro: Cerebro 实例
period_nbr: 周期分割数量
param_group_nbr: 参数组合并行数量
"""
self._cerebro = cerebro
self._period_nbr = period_nbr
self._param_group_nbr = param_group_nbr
def runPeriods(self, start_date: datetime, end_date: datetime,
data feeds: List) -> List[Dict]:
"""周期分割并行回测
将回测时间段分割成多个周期,分别回测后合并结果
Args:
start_date: 开始日期
end_date: 结束日期
data_feeds: 数据源列表
Returns:
合并后的回测结果列表
"""
# 获取交易日列表
trade_days = self._getTradeDays(start_date, end_date, data_feeds)
# 分割周期
periods = self._splitPeriods(trade_days, self._period_nbr)
# 并行回测
with mp.Pool(processes=self._period_nbr) as pool:
results = pool.starmap(
self._runSinglePeriod,
[(period, data_feeds) for period in periods]
)
# 合并结果
return self._mergeResults(results)
def _getTradeDays(self, start_date: datetime, end_date: datetime,
data_feeds: List) -> List[datetime]:
"""获取交易日列表"""
# 从数据源提取交易日
trade_days = []
for feed in data_feeds:
if hasattr(feed, 'datetime'):
dates = pd.to_datetime([feed.datetime[i] for i in range(len(feed))])
trade_days.extend(dates.tolist())
# 去重排序
trade_days = sorted(list(set(trade_days)))
trade_days = [d for d in trade_days if start_date <= d <= end_date]
return trade_days
def _splitPeriods(self, trade_days: List[datetime],
n: int) -> List[Tuple[datetime, datetime]]:
"""分割周期"""
if not trade_days:
return []
period_size = (len(trade_days) + n - 1) // n
periods = []
for i in range(0, len(trade_days), period_size):
period_days = trade_days[i:i + period_size]
if period_days:
periods.append((period_days[0], period_days[-1]))
return periods
def _runSinglePeriod(self, period: Tuple[datetime, datetime],
data_feeds: List) -> Dict:
"""运行单个周期的回测"""
start_date, end_date = period
# 创建新的 Cerebro 实例
cerebro = self._createCerebroCopy()
# 过滤数据
filtered_feeds = self._filterDataByDate(
data_feeds, start_date, end_date
)
# 添加数据
for feed in filtered_feeds:
cerebro.adddata(feed)
# 运行回测
results = cerebro.run()
return {
'period': period,
'results': results,
'final_value': cerebro.broker.getvalue(),
'final_cash': cerebro.broker.getcash(),
}
def _createCerebroCopy(self):
"""创建 Cerebro 副本"""
# 使用 pickle 序列化创建副本
import pickle
return pickle.loads(pickle.dumps(self._cerebro))
def _filterDataByDate(self, data_feeds: List, start_date: datetime,
end_date: datetime) -> List:
"""按日期过滤数据"""
filtered = []
for feed in data_feeds:
# 创建数据过滤
from .filters import DateFilter
filtered_feed = DateFilter(feed, start_date, end_date)
filtered.append(filtered_feed)
return filtered
def _mergeResults(self, results: List[Dict]) -> List[Dict]:
"""合并回测结果"""
if not results:
return []
# 合并统计指标
merged = {
'total_return': 0,
'sharpe_ratio': 0,
'max_drawdown': 0,
'total_trades': 0,
}
# 计算加权平均
total_value = sum(r['final_value'] for r in results)
for result in results:
weight = result['final_value'] / total_value if total_value > 0 else 0
# 累加指标
for key in merged.keys():
if key in result:
merged[key] += result[key] * weight
return [merged]
def runParamGroups(self, param_grid: Dict[str, List]) -> List[Dict]:
"""参数组合并行回测
Args:
param_grid: 参数字典 {参数名: [值列表]}
Returns:
各参数组合的回测结果
"""
# 生成参数组合
param_groups = self._createParamGroups(param_grid)
# 并行回测
with mp.Pool(processes=self._param_group_nbr or mp.cpu_count()) as pool:
results = pool.map(self._runSingleParamGroup, param_groups)
return results
def _createParamGroups(self, param_grid: Dict[str, List]) -> List[Dict]:
"""创建参数组合"""
import itertools
keys = list(param_grid.keys())
values = [param_grid[k] for k in keys]
combinations = itertools.product(*values)
return [dict(zip(keys, combo)) for combo in combinations]
def _runSingleParamGroup(self, params: Dict) -> Dict:
"""运行单个参数组合的回测"""
# 创建 Cerebro 副本
cerebro = self._createCerebroCopy()
# 设置参数
strategy = cerebro.runstrategies[0]
for key, value in params.items():
setattr(strategy.params, key, value)
# 运行回测
results = cerebro.run()
return {
'params': params,
'results': results,
}
```bash
### 4. 数据管理模块设计
#### 4.1 数据管理器
```python
# backtrader/data/manager.py
from typing import List, Dict, Any, Optional, Union
from datetime import datetime
import pandas as pd
from pathlib import Path
import pickle
import hashlib
class BtDataManager:
"""数据管理器
提供统一的数据获取、缓存和管理功能
"""
def __init__(self, cache_dir: Optional[str] = None):
"""初始化数据管理器
Args:
cache_dir: 缓存目录
"""
self._cache_dir = Path(cache_dir) if cache_dir else Path.home() / '.backtrader' / 'cache'
self._cache_dir.mkdir(parents=True, exist_ok=True)
# 数据缓存
self._memory_cache: Dict[str, Any] = {}
# 数据源配置
self._data_sources: Dict[str, Any] = {}
def registerDataSource(self, name: str, source: Any):
"""注册数据源"""
self._data_sources[name] = source
def loadData(self, name: str, start: Optional[datetime] = None,
end: Optional[datetime] = None,
use_cache: bool = True) -> Any:
"""加载数据
Args:
name: 数据名称或路径
start: 开始日期
end: 结束日期
use_cache: 是否使用缓存
Returns:
数据对象
"""
cache_key = self._getCacheKey(name, start, end)
# 检查内存缓存
if use_cache and cache_key in self._memory_cache:
return self._memory_cache[cache_key]
# 检查磁盘缓存
if use_cache:
cached_data = self._loadFromDiskCache(cache_key)
if cached_data is not None:
self._memory_cache[cache_key] = cached_data
return cached_data
# 从数据源加载
data = self._loadFromSource(name, start, end)
# 保存到缓存
if use_cache and data is not None:
self._saveToDiskCache(cache_key, data)
self._memory_cache[cache_key] = data
return data
def _loadFromSource(self, name: str, start: Optional[datetime],
end: Optional[datetime]) -> Any:
"""从数据源加载数据"""
# 检查是否为已注册数据源
if name in self._data_sources:
source = self._data_sources[name]
return source.load(start, end)
# 尝试从文件加载
path = Path(name)
if path.exists():
suffix = path.suffix.lower()
if suffix == '.csv':
return self._loadCsv(path, start, end)
elif suffix in ['.pkl', '.pickle']:
return self._loadPickle(path)
elif suffix in ['.h5', '.hdf5']:
return self._loadHdf5(path, start, end)
raise ValueError(f"Cannot load data: {name}")
def _loadCsv(self, path: Path, start: Optional[datetime],
end: Optional[datetime]) -> Any:
"""加载 CSV 数据"""
df = pd.read_csv(path)
# 转换日期
if 'datetime' in df.columns:
df['datetime'] = pd.to_datetime(df['datetime'])
elif 'date' in df.columns:
df['datetime'] = pd.to_datetime(df['date'])
# 过滤日期范围
if start is not None:
df = df[df['datetime'] >= start]
if end is not None:
df = df[df['datetime'] <= end]
# 转换为 backtrader 数据源
from ..feeds import PandasData
return PandasData(dataname=df)
def _loadPickle(self, path: Path) -> Any:
"""加载 Pickle 数据"""
with open(path, 'rb') as f:
return pickle.load(f)
def _loadHdf5(self, path: Path, start: Optional[datetime],
end: Optional[datetime]) -> Any:
"""加载 HDF5 数据"""
import tables
with tables.open_file(path, 'r') as h5file:
# 读取数据
# ...
return data
def _getCacheKey(self, name: str, start: Optional[datetime],
end: Optional[datetime]) -> str:
"""生成缓存键"""
key_str = f"{name}_{start}_{end}"
return hashlib.md5(key_str.encode()).hexdigest()
def _loadFromDiskCache(self, cache_key: str) -> Optional[Any]:
"""从磁盘缓存加载"""
cache_path = self._cache_dir / f"{cache_key}.pkl"
if cache_path.exists():
try:
with open(cache_path, 'rb') as f:
return pickle.load(f)
except Exception:
return None
return None
def _saveToDiskCache(self, cache_key: str, data: Any):
"""保存到磁盘缓存"""
cache_path = self._cache_dir / f"{cache_key}.pkl"
try:
with open(cache_path, 'wb') as f:
pickle.dump(data, f)
except Exception:
pass
def clearCache(self):
"""清空缓存"""
self._memory_cache.clear()
# 清空磁盘缓存
for cache_file in self._cache_dir.glob('*.pkl'):
try:
cache_file.unlink()
except Exception:
pass
def updateData(self, name: str):
"""更新数据
从数据源重新下载并更新数据
"""
# 清除缓存
cache_key = self._getCacheKey(name, None, None)
if cache_key in self._memory_cache:
del self._memory_cache[cache_key]
# 重新加载数据
return self.loadData(name, use_cache=False)
def validateData(self, data: Any) -> bool:
"""验证数据完整性"""
# 检查必需字段
required_fields = ['datetime', 'open', 'high', 'low', 'close', 'volume']
for field in required_fields:
if not hasattr(data, field):
return False
# 检查数据一致性
# ...
return True
```bash
### 5. 实盘交易引擎设计
#### 5.1 多账户管理
```python
# backtrader/trading/account_manager.py
from typing import Dict, List, Optional, Any
from enum import Enum
import threading
class AccountType(Enum):
"""账户类型"""
SIMULATION = "simulation" # 模拟账户
REAL = "real" # 实盘账户
class AccountStatus(Enum):
"""账户状态"""
IDLE = "idle" # 空闲
CONNECTING = "connecting" # 连接中
CONNECTED = "connected" # 已连接
DISCONNECTED = "disconnected" # 已断开
ERROR = "error" # 错误
class BtAccount:
"""交易账户"""
def __init__(self, account_id: str, account_type: AccountType,
broker: Any):
"""初始化账户
Args:
account_id: 账户 ID
account_type: 账户类型
broker: 券商接口
"""
self._account_id = account_id
self._account_type = account_type
self._broker = broker
self._status = AccountStatus.IDLE
self._cash = 0
self._value = 0
self._positions = {} # {symbol: position}
self._orders = {} # {order_id: order}
self._trades = [] # 历史成交
@property
def account_id(self) -> str:
return self._account_id
@property
def account_type(self) -> AccountType:
return self._account_type
@property
def status(self) -> AccountStatus:
return self._status
@property
def cash(self) -> float:
return self._cash
@property
def value(self) -> float:
return self._value
@property
def positions(self) -> Dict:
return self._positions.copy()
def connect(self) -> bool:
"""连接券商"""
self._status = AccountStatus.CONNECTING
try:
if self._broker.connect():
self._status = AccountStatus.CONNECTED
# 同步账户信息
self._syncAccount()
return True
except Exception as e:
self._status = AccountStatus.ERROR
print(f"Connection error: {e}")
return False
def disconnect(self):
"""断开连接"""
if self._broker:
self._broker.disconnect()
self._status = AccountStatus.DISCONNECTED
def buy(self, symbol: str, price: float, quantity: int) -> Optional[str]:
"""买入
Args:
symbol: 交易标的
price: 价格
quantity: 数量
Returns:
订单 ID
"""
if self._status != AccountStatus.CONNECTED:
print("Account not connected")
return None
order_id = self._broker.buy(symbol, price, quantity)
if order_id:
self._orders[order_id] = {
'order_id': order_id,
'symbol': symbol,
'direction': 'buy',
'price': price,
'quantity': quantity,
'status': 'submitted',
}
return order_id
def sell(self, symbol: str, price: float, quantity: int) -> Optional[str]:
"""卖出"""
if self._status != AccountStatus.CONNECTED:
print("Account not connected")
return None
order_id = self._broker.sell(symbol, price, quantity)
if order_id:
self._orders[order_id] = {
'order_id': order_id,
'symbol': symbol,
'direction': 'sell',
'price': price,
'quantity': quantity,
'status': 'submitted',
}
return order_id
def cancelOrder(self, order_id: str) -> bool:
"""撤销订单"""
if order_id not in self._orders:
return False
if self._broker.cancel(order_id):
self._orders[order_id]['status'] = 'canceled'
return True
return False
def _syncAccount(self):
"""同步账户信息"""
# 获取资金
self._cash = self._broker.getCash()
self._value = self._broker.getValue()
# 获取持仓
self._positions = self._broker.getPositions()
# 获取订单
self._orders = self._broker.getOrders()
def update(self):
"""更新账户状态"""
if self._status == AccountStatus.CONNECTED:
self._syncAccount()
class BtAccountManager:
"""多账户管理器
统一管理多个交易账户
"""
def __init__(self):
self._accounts: Dict[str, BtAccount] = {}
self._lock = threading.Lock()
def addAccount(self, account_id: str, account_type: AccountType,
broker: Any) -> BtAccount:
"""添加账户"""
with self._lock:
if account_id in self._accounts:
raise ValueError(f"Account {account_id} already exists")
account = BtAccount(account_id, account_type, broker)
self._accounts[account_id] = account
return account
def removeAccount(self, account_id: str):
"""移除账户"""
with self._lock:
if account_id in self._accounts:
self._accounts[account_id].disconnect()
del self._accounts[account_id]
def getAccount(self, account_id: str) -> Optional[BtAccount]:
"""获取账户"""
return self._accounts.get(account_id)
def getAllAccounts(self) -> List[BtAccount]:
"""获取所有账户"""
return list(self._accounts.values())
def connectAll(self) -> bool:
"""连接所有账户"""
success = True
for account in self._accounts.values():
if not account.connect():
success = False
return success
def disconnectAll(self):
"""断开所有账户"""
for account in self._accounts.values():
account.disconnect()
def updateAll(self):
"""更新所有账户状态"""
for account in self._accounts.values():
account.update()
@property
def totalCash(self) -> float:
"""所有账户总现金"""
return sum(acc.cash for acc in self._accounts.values())
@property
def totalValue(self) -> float:
"""所有账户总市值"""
return sum(acc.value for acc in self._accounts.values())
```bash
### 6. 监控报警系统设计
#### 6.1 监控引擎
```python
# backtrader/monitor/engine.py
from typing import Callable, Dict, List, Any, Optional
from datetime import datetime
import threading
import smtplib
from email.mime.text import MIMEText
class BtMonitorEngine:
"""监控引擎
监控策略运行状态,发送报警通知
"""
def __init__(self, event_engine):
"""初始化监控引擎
Args:
event_engine: 事件引擎
"""
self._event_engine = event_engine
self._monitors: Dict[str, BtMonitor] = {}
self._notifiers: List[BtNotifier] = []
# 注册事件
self._registerEvents()
def _registerEvents(self):
"""注册事件处理器"""
self._event_engine.register(
BtEventType.strategySignal,
self._onStrategySignal
)
self._event_engine.register(
BtEventType.monitorAlert,
self._onAlert
)
def addMonitor(self, name: str, monitor: 'BtMonitor'):
"""添加监控器"""
self._monitors[name] = monitor
def removeMonitor(self, name: str):
"""移除监控器"""
if name in self._monitors:
del self._monitors[name]
def addNotifier(self, notifier: 'BtNotifier'):
"""添加通知器"""
self._notifiers.append(notifier)
def _onStrategySignal(self, event):
"""处理策略信号事件"""
signal = event.data
self._notify(
title=f"策略信号: {signal.get('strategy')}",
message=f"标的: {signal.get('symbol')}\n"
f"方向: {signal.get('direction')}\n"
f"价格: {signal.get('price')}\n"
f"时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
)
def _onAlert(self, event):
"""处理报警事件"""
alert = event.data
self._notify(
title=f"报警: {alert.get('level', 'INFO')}",
message=alert.get('message', ''),
level=alert.get('level', 'INFO')
)
def _notify(self, title: str, message: str, level: str = 'INFO'):
"""发送通知"""
for notifier in self._notifiers:
try:
notifier.send(title, message, level)
except Exception as e:
print(f"Notifier error: {e}")
def start(self):
"""启动监控"""
for monitor in self._monitors.values():
monitor.start()
def stop(self):
"""停止监控"""
for monitor in self._monitors.values():
monitor.stop()
class BtMonitor:
"""监控器基类"""
def __init__(self, name: str, check_interval: int = 60):
"""初始化监控器
Args:
name: 监控器名称
check_interval: 检查间隔(秒)
"""
self._name = name
self._check_interval = check_interval
self._active = False
self._thread = None
def start(self):
"""启动监控"""
if not self._active:
self._active = True
self._thread = threading.Thread(target=self._runLoop)
self._thread.daemon = True
self._thread.start()
def stop(self):
"""停止监控"""
self._active = False
if self._thread:
self._thread.join(timeout=5)
def _runLoop(self):
"""监控循环"""
import time
while self._active:
try:
self.check()
except Exception as e:
print(f"Monitor {self._name} error: {e}")
time.sleep(self._check_interval)
def check(self):
"""检查条件,子类实现"""
raise NotImplementedError
class BtStrategyMonitor(BtMonitor):
"""策略监控器
监控策略运行状态
"""
def __init__(self, strategy, event_engine):
super().__init__(f"StrategyMonitor_{id(strategy)}")
self._strategy = strategy
self._event_engine = event_engine
self._last_value = None
def check(self):
"""检查策略状态"""
# 检查策略是否还在运行
if not hasattr(self._strategy, 'isrunning') or not self._strategy.isrunning:
self._event_engine.put(BtEvent(
BtEventType.monitorAlert,
{
'level': 'WARNING',
'message': f'策略 {self._strategy.__class__.__name__} 已停止运行'
}
))
# 检查策略价值变化
current_value = self._strategy.broker.getvalue()
if self._last_value is not None:
change_pct = (current_value - self._last_value) / self._last_value
# 单日跌幅超过 5%报警
if change_pct < -0.05:
self._event_engine.put(BtEvent(
BtEventType.monitorAlert,
{
'level': 'CRITICAL',
'message': f'策略单日跌幅: {change_pct*100:.2f}%'
}
))
self._last_value = current_value
class BtNotifier:
"""通知器基类"""
def send(self, title: str, message: str, level: str = 'INFO'):
"""发送通知
Args:
title: 标题
message: 消息内容
level: 级别
"""
raise NotImplementedError
class BtEmailNotifier(BtNotifier):
"""邮件通知器"""
def __init__(self, smtp_server: str, from_addr: str, password: str,
to_addrs: List[str]):
"""初始化邮件通知器
Args:
smtp_server: SMTP 服务器
from_addr: 发件人地址
password: 密码
to_addrs: 收件人地址列表
"""
self._smtp_server = smtp_server
self._from_addr = from_addr
self._password = password
self._to_addrs = to_addrs
def send(self, title: str, message: str, level: str = 'INFO'):
"""发送邮件"""
msg = MIMEText(message)
msg['Subject'] = f"[{level}] {title}"
msg['From'] = self._from_addr
msg['To'] = ', '.join(self._to_addrs)
try:
with smtplib.SMTP(self._smtp_server, 587) as server:
server.starttls()
server.login(self._from_addr, self._password)
server.send_message(msg)
except Exception as e:
print(f"Email send error: {e}")
class BtLogNotifier(BtNotifier):
"""日志通知器"""
def __init__(self, log_file: str):
"""初始化日志通知器"""
self._log_file = log_file
def send(self, title: str, message: str, level: str = 'INFO'):
"""写入日志"""
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
with open(self._log_file, 'a') as f:
f.write(f"{timestamp} [{level}] {title}\n{message}\n\n")
```bash
#### 6.2 微信通知器(可选)
```python
# backtrader/monitor/wechat.py
class BtWeChatNotifier(BtNotifier):
"""微信通知器
使用企业微信或微信测试号发送通知
"""
def __init__(self, webhook_url: str):
"""初始化微信通知器
Args:
webhook_url: 企业微信机器人 Webhook URL
"""
self._webhook_url = webhook_url
def send(self, title: str, message: str, level: str = 'INFO'):
"""发送微信通知"""
import requests
import json
# 根据级别选择颜色
colors = {
'INFO': 'info',
'WARNING': 'warning',
'ERROR': 'comment',
'CRITICAL': 'warning',
}
data = {
"msgtype": "markdown",
"markdown": {
"content": f"## <font color=\"{colors.get(level, 'info')}\">{title}</font>\n\n"
f"> {message.replace(chr(10), '> ')}"
}
}
try:
response = requests.post(
self._webhook_url,
data=json.dumps(data),
headers={'Content-Type': 'application/json'}
)
return response.status_code == 200
except Exception as e:
print(f"WeChat send error: {e}")
return False
```bash
### 7. 实施计划
#### 7.1 实施优先级
1. **高优先级**(第一阶段)
- 增强事件引擎 - 基础设施
- 数据管理模块 - 提升数据管理能力
2.**中优先级**(第二阶段)
- 策略回放功能 - 提升回测效率
- 监控报警系统 - 增强系统可靠性
3.**可选优先级** (第三阶段)
- PyQt 图形界面 - 提升用户体验
- 实盘交易引擎 - 扩展实盘能力
#### 7.2 向后兼容性保证
所有新功能都是**可选的**,现有代码无需修改即可继续使用:
```python
# 现有用法继续支持
cerebro = bt.Cerebro()
cerebro.adddata(data)
cerebro.addstrategy(MyStrategy)
results = cerebro.run()
# 新用法
# 事件引擎
event_engine = bt.BtEventEngine()
cerebro.set_event_engine(event_engine)
# 数据管理器
data_mgr = bt.BtDataManager()
data = data_mgr.loadData('AAPL', start, end)
# 并行回测
parallel = bt.BtParallelBackTest(cerebro)
results = parallel.runPeriods(start, end, [data])
```bash
#### 7.3 目录结构
```bash
backtrader/
├── __init__.py
├── cerebro.py # 核心引擎 (修改)
├── events/ # 新增: 事件系统
│ ├── __init__.py
│ ├── engine.py # 事件引擎
│ └── event.py # 事件定义
├── data/ # 修改: 数据模块
│ ├── manager.py # 新增: 数据管理器
│ └── ...
├── parallel.py # 新增: 并行回测
├── trading/ # 新增: 交易模块
│ ├── __init__.py
│ ├── account_manager.py # 账户管理
│ └── broker/ # 券商接口
├── monitor/ # 新增: 监控模块
│ ├── __init__.py
│ ├── engine.py # 监控引擎
│ └── wechat.py # 微信通知
└── gui/ # 新增: 图形界面
├── __init__.py
├── main_window.py # 主窗口
├── basic_window.py # 基础窗口
├── backtest_window.py # 回测窗口
├── data_window.py # 数据窗口
├── monitor_window.py # 监控窗口
└── settings_window.py # 设置窗口
```bash
- --
## 总结
通过借鉴 DevilYuan 的设计思想,backtrader 可以在保持通用性的同时,获得以下改进:
1. **GUI 支持**: PyQt 图形界面提升用户体验,降低使用门槛
2. **事件驱动**: 清晰的事件引擎架构,实现组件间松耦合
3. **并行回测**: 周期分割并行处理,大幅提升回测效率
4. **数据管理**: 统一的数据管理模块,支持多数据源和缓存
5. **实盘交易**: 增强的交易引擎,支持多账户管理
6. **监控报警**: 完善的监控报警系统,实时掌握策略状态
这些改进都是**向后兼容**的,用户可以按需使用新功能,不影响现有策略代码。DevilYuan 作为针对 A 股市场的完整交易系统,其在事件驱动架构、GUI 设计和实盘交易方面的实践经验对 backtrader 的扩展具有重要参考价值。