diff --git a/src/aitrader/main.py b/src/aitrader/main.py index 0917edf..bab2536 100644 --- a/src/aitrader/main.py +++ b/src/aitrader/main.py @@ -67,7 +67,7 @@ def run_tick( last = float(snap.ticker.get("last") or snap.ohlcv[settings.timeframes[0]]["close"].iloc[-1]) last_prices[symbol] = last - closed_ids = portfolio.check_stop_take_profit(settings, symbol, last) + closed_ids = portfolio.check_stop_take_profit(settings, kraken, symbol, last) if closed_ids: log.info("trade.sl_tp_closed", symbol=symbol, trade_ids=closed_ids) diff --git a/src/aitrader/notify/bot.py b/src/aitrader/notify/bot.py index dbf777c..d2093fc 100644 --- a/src/aitrader/notify/bot.py +++ b/src/aitrader/notify/bot.py @@ -111,16 +111,26 @@ class _AitraderBot(discord.Client): return deleted +_bot_instance: _AitraderBot | None = None + + +def get_bot_instance() -> _AitraderBot | None: + return _bot_instance + + def start_bot(settings: Settings) -> None: + global _bot_instance if not settings.discord_bot_token: log.info("discord_bot.disabled", reason="kein Token gesetzt") return state.init(settings.db_path) def _run() -> None: - bot = _AitraderBot(settings) - asyncio.run(bot.start(settings.discord_bot_token)) + global _bot_instance + _bot_instance = _AitraderBot(settings) + asyncio.run(_bot_instance.start(settings.discord_bot_token)) t = threading.Thread(target=_run, daemon=True, name="discord-bot") t.start() log.info("discord_bot.thread_started") + diff --git a/src/aitrader/notify/discord.py b/src/aitrader/notify/discord.py index 006a248..8560766 100644 --- a/src/aitrader/notify/discord.py +++ b/src/aitrader/notify/discord.py @@ -1,9 +1,11 @@ -"""Discord-Webhook-Notifier.""" +"""Discord-Webhook-Notifier (mit Bot-Support).""" from __future__ import annotations +import asyncio from datetime import datetime, timezone from typing import Any +import discord import requests from ..config import Settings @@ -23,7 +25,61 @@ _error_last_sent: dict[str, float] = {} # key -> timestamp def _enabled(settings: Settings) -> bool: - return bool(settings.discord.enabled and settings.discord_webhook_url) + return bool( + settings.discord.enabled + and (settings.discord_webhook_url or settings.discord_bot_token) + ) + + +def send_via_bot( + settings: Settings, + embed_dict: dict[str, Any], + channel_type: str = "trades", + content: str | None = None, +) -> bool: + from .bot import get_bot_instance + bot = get_bot_instance() + if bot is None or not bot.is_ready(): + return False + + channel_id_str = ( + settings.discord_decisions_channel_id + if channel_type == "decisions" + else settings.discord_channel_id + ) + if not channel_id_str: + log.warning("discord_bot.channel_not_configured", channel_type=channel_type) + return False + + try: + channel_id = int(channel_id_str) + except ValueError: + log.warning("discord_bot.invalid_channel_id", channel_id=channel_id_str) + return False + + embed = discord.Embed.from_dict(embed_dict) + + async def _send(): + ch = bot.get_channel(channel_id) + if not ch: + try: + ch = await bot.fetch_channel(channel_id) + except Exception as ex: + log.warning("discord_bot.fetch_channel_failed", channel_id=channel_id, error=str(ex)) + return + if isinstance(ch, discord.TextChannel): + await ch.send(content=content, embed=embed) + log.info("discord_bot.message_sent", channel_id=channel_id, channel_type=channel_type) + else: + log.warning("discord_bot.channel_not_text_channel", channel_id=channel_id) + + future = asyncio.run_coroutine_threadsafe(_send(), bot.loop) + try: + future.result(timeout=10.0) + return True + except Exception as e: + log.error("discord_bot.send_failed", error=str(e)) + return False def _post( @@ -34,11 +90,25 @@ def _post( ) -> None: if not _enabled(settings): return + + embed.setdefault("timestamp", datetime.now(timezone.utc).isoformat()) + + # Versuche zuerst über den Discord Bot zu senden + if settings.discord_bot_token: + sent = send_via_bot(settings, embed, channel_type=channel, content=content) + if sent: + return + log.warning("discord.bot_send_failed_falling_back_to_webhook") + + # Fallback zu Webhook if channel == "decisions" and settings.discord_webhook_decisions_url: url = settings.discord_webhook_decisions_url else: url = settings.discord_webhook_url - embed.setdefault("timestamp", datetime.now(timezone.utc).isoformat()) + + if not url: + return + payload: dict[str, Any] = {"embeds": [embed]} if content: payload["content"] = content @@ -50,6 +120,7 @@ def _post( log.warning("discord.post_exception", error=str(e)) + def _should(settings: Settings, event: str) -> bool: return _enabled(settings) and event in settings.discord.notify_on and state.is_enabled(event) diff --git a/src/aitrader/storage/db.py b/src/aitrader/storage/db.py index 422ee76..b104562 100644 --- a/src/aitrader/storage/db.py +++ b/src/aitrader/storage/db.py @@ -13,9 +13,10 @@ _engine = None def get_engine(db_path: str): global _engine - if _engine is None: + db_url = f"sqlite:///{db_path}" + if _engine is None or str(_engine.url) != db_url: Path(db_path).parent.mkdir(parents=True, exist_ok=True) - _engine = create_engine(f"sqlite:///{db_path}", echo=False) + _engine = create_engine(db_url, echo=False) SQLModel.metadata.create_all(_engine) return _engine diff --git a/src/aitrader/trader/portfolio.py b/src/aitrader/trader/portfolio.py index e8a8089..374a43f 100644 --- a/src/aitrader/trader/portfolio.py +++ b/src/aitrader/trader/portfolio.py @@ -2,15 +2,22 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import Iterable +from typing import Iterable, TYPE_CHECKING from sqlmodel import select from ..config import Settings +from ..logging_setup import get_logger from ..notify import discord from ..storage import db as dbm from ..storage.models import EquitySnapshot, Trade +if TYPE_CHECKING: + from ..exchange.kraken import KrakenClient + +log = get_logger(__name__) + + def open_trades_for_symbol(settings: Settings, symbol: str) -> list[Trade]: with dbm.session(settings.db_path) as s: @@ -24,11 +31,22 @@ def all_open_trades(settings: Settings) -> list[Trade]: return list(s.exec(select(Trade).where(Trade.status == "open")).all()) -def close_trade(settings: Settings, trade_id: int, exit_price: float) -> None: +def close_trade(settings: Settings, kraken: KrakenClient, trade_id: int, exit_price: float) -> None: with dbm.session(settings.db_path) as s: t = s.get(Trade, trade_id) if not t or t.status != "open": return + + # Close on exchange first (will simulate if paper_only is true) + ccxt_side = "sell" if t.side == "buy" else "buy" + try: + order = kraken.create_market_order(t.symbol, ccxt_side, t.qty) + log.info("close_trade.exchange_success", symbol=t.symbol, side=ccxt_side, qty=t.qty, order_id=order.get("id")) + except Exception as e: + log.error("close_trade.exchange_failed", error=str(e), symbol=t.symbol, qty=t.qty) + discord.notify_error(settings, f"close_trade.exchange_failed ({t.symbol})", str(e)) + return + t.exit_price = exit_price t.exit_ts = datetime.now(timezone.utc) sign = 1 if t.side == "buy" else -1 @@ -41,7 +59,7 @@ def close_trade(settings: Settings, trade_id: int, exit_price: float) -> None: def check_stop_take_profit( - settings: Settings, symbol: str, current_price: float + settings: Settings, kraken: KrakenClient, symbol: str, current_price: float ) -> list[int]: """Schließt Trades, wenn SL/TP erreicht. Gibt geschlossene Trade-IDs zurück.""" closed: list[int] = [] @@ -58,8 +76,12 @@ def check_stop_take_profit( elif t.take_profit and current_price <= t.take_profit: hit = True if hit: - close_trade(settings, t.id, current_price) - closed.append(t.id) + close_trade(settings, kraken, t.id, current_price) + # Verify if it was successfully closed in the DB before counting it + with dbm.session(settings.db_path) as s: + updated_t = s.get(Trade, t.id) + if updated_t and updated_t.status == "closed": + closed.append(t.id) return closed diff --git a/src/aitrader/trader/risk.py b/src/aitrader/trader/risk.py index 21a0bf6..43a70af 100644 --- a/src/aitrader/trader/risk.py +++ b/src/aitrader/trader/risk.py @@ -2,7 +2,7 @@ from __future__ import annotations from dataclasses import dataclass -from datetime import datetime, timedelta, timezone +from datetime import datetime, timezone from sqlmodel import select diff --git a/tests/test_notify.py b/tests/test_notify.py new file mode 100644 index 0000000..5581a68 --- /dev/null +++ b/tests/test_notify.py @@ -0,0 +1,50 @@ +import asyncio +import pytest +from unittest.mock import MagicMock, patch + +from aitrader.config import Settings +from aitrader.notify import discord + + +def test_discord_enabled_with_webhook(): + s = Settings(discord=dict(enabled=True), discord_webhook_url="https://xyz") + assert discord._enabled(s) + + +def test_discord_enabled_with_bot(): + s = Settings(discord=dict(enabled=True), discord_bot_token="xyz") + assert discord._enabled(s) + + +def test_discord_disabled(): + s = Settings( + discord=dict(enabled=False), + discord_webhook_url="https://xyz", + discord_bot_token="xyz", + ) + assert not discord._enabled(s) + + +@patch("aitrader.notify.bot.get_bot_instance") +def test_send_via_bot_success(mock_get_bot): + mock_bot = MagicMock() + mock_bot.is_ready.return_value = True + + mock_channel = MagicMock() + mock_bot.get_channel.return_value = mock_channel + + mock_get_bot.return_value = mock_bot + + s = Settings(discord_channel_id="12345", discord_bot_token="token") + + def run_coro_sync(coro, loop): + asyncio.run(coro) + fut = MagicMock() + fut.result.return_value = None + return fut + + with patch("asyncio.run_coroutine_threadsafe", side_effect=run_coro_sync) as mock_run: + res = discord.send_via_bot(s, {"title": "Test"}, channel_type="trades") + assert res + mock_run.assert_called_once() + mock_bot.get_channel.assert_called_once_with(12345) diff --git a/tests/test_portfolio.py b/tests/test_portfolio.py new file mode 100644 index 0000000..76f8304 --- /dev/null +++ b/tests/test_portfolio.py @@ -0,0 +1,172 @@ +import pytest +from unittest.mock import MagicMock + +from aitrader.config import Settings +from aitrader.storage import db as dbm +from aitrader.storage.models import Trade +from aitrader.trader import portfolio + + +@pytest.fixture +def settings(tmp_path): + s = Settings(starting_equity_eur=10000.0, db_path=str(tmp_path / "t.db")) + # Initialize DB + dbm.get_engine(s.db_path) + return s + + +@pytest.fixture +def mock_kraken(): + kraken = MagicMock() + kraken.create_market_order.return_value = {"id": "mock_order_123"} + return kraken + + +def test_check_stop_take_profit_buy_sl_hit(settings, mock_kraken): + # Setup open buy trade + with dbm.session(settings.db_path) as s: + trade = Trade( + symbol="BTC/USD:USD", + side="buy", + qty=1.5, + entry_price=60000.0, + stop_loss=59000.0, + take_profit=62000.0, + status="open" + ) + s.add(trade) + s.commit() + s.refresh(trade) + trade_id = trade.id + + # Under SL (58500 <= 59000) + closed = portfolio.check_stop_take_profit(settings, mock_kraken, "BTC/USD:USD", 58500.0) + + assert closed == [trade_id] + mock_kraken.create_market_order.assert_called_once_with("BTC/USD:USD", "sell", 1.5) + + # Check DB update + with dbm.session(settings.db_path) as s: + t = s.get(Trade, trade_id) + assert t.status == "closed" + assert t.exit_price == 58500.0 + assert t.pnl_eur == (58500.0 - 60000.0) * 1.5 + + +def test_check_stop_take_profit_buy_tp_hit(settings, mock_kraken): + with dbm.session(settings.db_path) as s: + trade = Trade( + symbol="BTC/USD:USD", + side="buy", + qty=1.0, + entry_price=60000.0, + stop_loss=59000.0, + take_profit=62000.0, + status="open" + ) + s.add(trade) + s.commit() + s.refresh(trade) + trade_id = trade.id + + # Above TP (62500 >= 62000) + closed = portfolio.check_stop_take_profit(settings, mock_kraken, "BTC/USD:USD", 62500.0) + + assert closed == [trade_id] + mock_kraken.create_market_order.assert_called_once_with("BTC/USD:USD", "sell", 1.0) + + # Check DB update + with dbm.session(settings.db_path) as s: + t = s.get(Trade, trade_id) + assert t.status == "closed" + assert t.exit_price == 62500.0 + assert t.pnl_eur == (62500.0 - 60000.0) * 1.0 + + +def test_check_stop_take_profit_no_hit(settings, mock_kraken): + with dbm.session(settings.db_path) as s: + trade = Trade( + symbol="BTC/USD:USD", + side="buy", + qty=1.0, + entry_price=60000.0, + stop_loss=59000.0, + take_profit=62000.0, + status="open" + ) + s.add(trade) + s.commit() + s.refresh(trade) + trade_id = trade.id + + # Price between SL and TP + closed = portfolio.check_stop_take_profit(settings, mock_kraken, "BTC/USD:USD", 60500.0) + + assert closed == [] + mock_kraken.create_market_order.assert_not_called() + + # Check DB update + with dbm.session(settings.db_path) as s: + t = s.get(Trade, trade_id) + assert t.status == "open" + + +def test_check_stop_take_profit_sell_sl_hit(settings, mock_kraken): + # Setup open sell trade + with dbm.session(settings.db_path) as s: + trade = Trade( + symbol="ETH/USD:USD", + side="sell", + qty=2.0, + entry_price=3000.0, + stop_loss=3100.0, + take_profit=2800.0, + status="open" + ) + s.add(trade) + s.commit() + s.refresh(trade) + trade_id = trade.id + + # Above SL for Sell (3150 >= 3100) + closed = portfolio.check_stop_take_profit(settings, mock_kraken, "ETH/USD:USD", 3150.0) + + assert closed == [trade_id] + mock_kraken.create_market_order.assert_called_once_with("ETH/USD:USD", "buy", 2.0) + + # Check DB update + with dbm.session(settings.db_path) as s: + t = s.get(Trade, trade_id) + assert t.status == "closed" + assert t.exit_price == 3150.0 + assert t.pnl_eur == -1 * (3150.0 - 3000.0) * 2.0 + + +def test_check_stop_take_profit_exchange_fails(settings, mock_kraken): + mock_kraken.create_market_order.side_effect = Exception("Kraken API down") + + with dbm.session(settings.db_path) as s: + trade = Trade( + symbol="BTC/USD:USD", + side="buy", + qty=1.0, + entry_price=60000.0, + stop_loss=59000.0, + take_profit=62000.0, + status="open" + ) + s.add(trade) + s.commit() + s.refresh(trade) + trade_id = trade.id + + # Under SL, but exchange order fails + closed = portfolio.check_stop_take_profit(settings, mock_kraken, "BTC/USD:USD", 58500.0) + + assert closed == [] + mock_kraken.create_market_order.assert_called_once_with("BTC/USD:USD", "sell", 1.0) + + # DB record must still be open + with dbm.session(settings.db_path) as s: + t = s.get(Trade, trade_id) + assert t.status == "open"