From 260837d3aa57b1b68fb7c1887a0d93da4ad6e1f4 Mon Sep 17 00:00:00 2001 From: Stefano Bertelli Date: Tue, 31 Mar 2026 23:25:30 -0500 Subject: [PATCH] fix: clean up orphaned nftables chains on reconcile rebuild_all_rules now discovers existing user_ chains and removes any that are no longer in the DB. Reconcile always runs the firewall rebuild even with 0 devices, so stale forward rules and orphan chains are cleaned up when all devices are deleted. --- tests/test_firewall_extended.py | 62 ++++++++++++++++++++++++++++++++- wiregui/services/firewall.py | 35 +++++++++++++++++-- wiregui/tasks/reconcile.py | 9 +++-- 3 files changed, 97 insertions(+), 9 deletions(-) diff --git a/tests/test_firewall_extended.py b/tests/test_firewall_extended.py index 08a8df3..db550b2 100644 --- a/tests/test_firewall_extended.py +++ b/tests/test_firewall_extended.py @@ -8,6 +8,7 @@ from wiregui.services.firewall import ( _nft, _nft_batch, add_device_jump_rule, + rebuild_all_rules, setup_base_tables, setup_masquerade, apply_peer_to_peer_policy, @@ -203,4 +204,63 @@ async def test_get_ruleset_returns_fallback_on_error(mock_nft): """get_ruleset returns friendly message when nft not available.""" mock_nft.side_effect = RuntimeError("nft not found") result = await get_ruleset() - assert "not available" in result \ No newline at end of file + assert "not available" in result + + +# ========== rebuild_all_rules — orphan cleanup ========== + + +@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock) +@patch("wiregui.services.firewall._list_user_chains", new_callable=AsyncMock) +async def test_rebuild_removes_orphaned_user_chains(mock_list, mock_batch): + """Orphaned user chains (in nft but not in DB) should be flushed and deleted.""" + mock_list.return_value = {"user_aaaa00000000", "user_bbbb00000000"} + + # Only user_aaaa is still in the DB + await rebuild_all_rules([{ + "user_id": "aaaa0000-0000-0000-0000-000000000000", + "devices": [{"ipv4": "10.0.0.2", "ipv6": None}], + "rules": [], + }]) + + batch_cmds = mock_batch.call_args[0][0] + batch_text = "\n".join(batch_cmds) + # user_bbbb should be flushed and deleted + assert "flush chain inet wiregui user_bbbb00000000" in batch_text + assert "delete chain inet wiregui user_bbbb00000000" in batch_text + # user_aaaa should NOT be deleted + assert "delete chain inet wiregui user_aaaa00000000" not in batch_text + + +@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock) +@patch("wiregui.services.firewall._list_user_chains", new_callable=AsyncMock) +async def test_rebuild_with_no_devices_clears_forward_and_orphans(mock_list, mock_batch): + """With zero devices, forward chain should be flushed and all user chains removed.""" + mock_list.return_value = {"user_aaaa00000000", "user_bbbb00000000"} + + await rebuild_all_rules([]) + + batch_cmds = mock_batch.call_args[0][0] + batch_text = "\n".join(batch_cmds) + # Forward chain must be flushed even with no devices + assert "flush chain inet wiregui forward" in batch_text + # Both orphans removed + assert "delete chain inet wiregui user_aaaa00000000" in batch_text + assert "delete chain inet wiregui user_bbbb00000000" in batch_text + + +@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock) +@patch("wiregui.services.firewall._list_user_chains", new_callable=AsyncMock) +async def test_rebuild_no_orphans_no_deletions(mock_list, mock_batch): + """When all nft chains match the DB, no deletions should occur.""" + mock_list.return_value = {"user_aaaa00000000"} + + await rebuild_all_rules([{ + "user_id": "aaaa0000-0000-0000-0000-000000000000", + "devices": [{"ipv4": "10.0.0.2", "ipv6": None}], + "rules": [], + }]) + + batch_cmds = mock_batch.call_args[0][0] + batch_text = "\n".join(batch_cmds) + assert "delete chain" not in batch_text \ No newline at end of file diff --git a/wiregui/services/firewall.py b/wiregui/services/firewall.py index 489afdc..ea7c0d0 100644 --- a/wiregui/services/firewall.py +++ b/wiregui/services/firewall.py @@ -129,10 +129,17 @@ async def apply_rule(user_id: str, destination: str, action: str, port_type: str async def rebuild_all_rules(users_devices_rules: list[dict]) -> None: """Full reconciliation: flush and rebuild all per-user chains from DB state. + Removes orphaned user chains that are no longer in the DB. + Args: users_devices_rules: list of dicts with keys: user_id, devices (list of {ipv4, ipv6}), rules (list of {destination, action, port_type, port_range}) """ + # Discover existing user_ chains so we can remove orphans + existing_user_chains = await _list_user_chains() + expected_chains = {_user_chain_name(e["user_id"]) for e in users_devices_rules} + orphaned_chains = existing_user_chains - expected_chains + commands = [] for entry in users_devices_rules: @@ -162,9 +169,16 @@ async def rebuild_all_rules(users_devices_rules: list[dict]) -> None: if dev.get("ipv6"): commands.append(f"add rule inet {TABLE_NAME} forward ip6 saddr {dev['ipv6']} jump {chain}") - if commands: - await _nft_batch(commands) - logger.info("Firewall rules rebuilt for {} users", len(users_devices_rules)) + # Remove orphaned user chains (must happen after forward chain is flushed + # so there are no remaining jump references to these chains) + for chain in orphaned_chains: + commands.append(f"flush chain inet {TABLE_NAME} {chain}") + commands.append(f"delete chain inet {TABLE_NAME} {chain}") + + await _nft_batch(commands) + if orphaned_chains: + logger.info("Removed {} orphaned firewall chain(s): {}", len(orphaned_chains), orphaned_chains) + logger.info("Firewall rules rebuilt for {} users", len(users_devices_rules)) async def apply_peer_to_peer_policy(enabled: bool) -> None: @@ -235,6 +249,21 @@ async def get_ruleset() -> str: return "nftables is not available.\n\nThis requires root/NET_ADMIN privileges (production container)." +async def _list_user_chains() -> set[str]: + """Return the set of user_ chain names currently in the wiregui table.""" + try: + output = await _nft(f"list table inet {TABLE_NAME}") + except RuntimeError: + return set() + chains = set() + for line in output.splitlines(): + line = line.strip() + if line.startswith("chain user_"): + name = line.split()[1] + chains.add(name) + return chains + + def _user_chain_name(user_id: str) -> str: """Generate a deterministic chain name from a user ID.""" # Use first 12 chars of UUID (without hyphens) to keep names short diff --git a/wiregui/tasks/reconcile.py b/wiregui/tasks/reconcile.py index 3ed8773..3c5634c 100644 --- a/wiregui/tasks/reconcile.py +++ b/wiregui/tasks/reconcile.py @@ -83,8 +83,7 @@ async def _reconcile_firewall(devices: list[Device], rules: list[Rule]) -> None: ], }) - if entries: - try: - await firewall.rebuild_all_rules(entries) - except Exception as e: - logger.error("Reconcile: firewall rebuild failed: {}", e) + try: + await firewall.rebuild_all_rules(entries) + except Exception as e: + logger.error("Reconcile: firewall rebuild failed: {}", e)