fix: clean up orphaned nftables chains on reconcile
rebuild_all_rules now discovers existing user_ chains and removes any that are no longer in the DB. Reconcile always runs the firewall rebuild even with 0 devices, so stale forward rules and orphan chains are cleaned up when all devices are deleted.
This commit is contained in:
parent
0f5e517f9d
commit
260837d3aa
3 changed files with 97 additions and 9 deletions
|
|
@ -8,6 +8,7 @@ from wiregui.services.firewall import (
|
||||||
_nft,
|
_nft,
|
||||||
_nft_batch,
|
_nft_batch,
|
||||||
add_device_jump_rule,
|
add_device_jump_rule,
|
||||||
|
rebuild_all_rules,
|
||||||
setup_base_tables,
|
setup_base_tables,
|
||||||
setup_masquerade,
|
setup_masquerade,
|
||||||
apply_peer_to_peer_policy,
|
apply_peer_to_peer_policy,
|
||||||
|
|
@ -203,4 +204,63 @@ async def test_get_ruleset_returns_fallback_on_error(mock_nft):
|
||||||
"""get_ruleset returns friendly message when nft not available."""
|
"""get_ruleset returns friendly message when nft not available."""
|
||||||
mock_nft.side_effect = RuntimeError("nft not found")
|
mock_nft.side_effect = RuntimeError("nft not found")
|
||||||
result = await get_ruleset()
|
result = await get_ruleset()
|
||||||
assert "not available" in result
|
assert "not available" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ========== rebuild_all_rules — orphan cleanup ==========
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||||
|
@patch("wiregui.services.firewall._list_user_chains", new_callable=AsyncMock)
|
||||||
|
async def test_rebuild_removes_orphaned_user_chains(mock_list, mock_batch):
|
||||||
|
"""Orphaned user chains (in nft but not in DB) should be flushed and deleted."""
|
||||||
|
mock_list.return_value = {"user_aaaa00000000", "user_bbbb00000000"}
|
||||||
|
|
||||||
|
# Only user_aaaa is still in the DB
|
||||||
|
await rebuild_all_rules([{
|
||||||
|
"user_id": "aaaa0000-0000-0000-0000-000000000000",
|
||||||
|
"devices": [{"ipv4": "10.0.0.2", "ipv6": None}],
|
||||||
|
"rules": [],
|
||||||
|
}])
|
||||||
|
|
||||||
|
batch_cmds = mock_batch.call_args[0][0]
|
||||||
|
batch_text = "\n".join(batch_cmds)
|
||||||
|
# user_bbbb should be flushed and deleted
|
||||||
|
assert "flush chain inet wiregui user_bbbb00000000" in batch_text
|
||||||
|
assert "delete chain inet wiregui user_bbbb00000000" in batch_text
|
||||||
|
# user_aaaa should NOT be deleted
|
||||||
|
assert "delete chain inet wiregui user_aaaa00000000" not in batch_text
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||||
|
@patch("wiregui.services.firewall._list_user_chains", new_callable=AsyncMock)
|
||||||
|
async def test_rebuild_with_no_devices_clears_forward_and_orphans(mock_list, mock_batch):
|
||||||
|
"""With zero devices, forward chain should be flushed and all user chains removed."""
|
||||||
|
mock_list.return_value = {"user_aaaa00000000", "user_bbbb00000000"}
|
||||||
|
|
||||||
|
await rebuild_all_rules([])
|
||||||
|
|
||||||
|
batch_cmds = mock_batch.call_args[0][0]
|
||||||
|
batch_text = "\n".join(batch_cmds)
|
||||||
|
# Forward chain must be flushed even with no devices
|
||||||
|
assert "flush chain inet wiregui forward" in batch_text
|
||||||
|
# Both orphans removed
|
||||||
|
assert "delete chain inet wiregui user_aaaa00000000" in batch_text
|
||||||
|
assert "delete chain inet wiregui user_bbbb00000000" in batch_text
|
||||||
|
|
||||||
|
|
||||||
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
|
||||||
|
@patch("wiregui.services.firewall._list_user_chains", new_callable=AsyncMock)
|
||||||
|
async def test_rebuild_no_orphans_no_deletions(mock_list, mock_batch):
|
||||||
|
"""When all nft chains match the DB, no deletions should occur."""
|
||||||
|
mock_list.return_value = {"user_aaaa00000000"}
|
||||||
|
|
||||||
|
await rebuild_all_rules([{
|
||||||
|
"user_id": "aaaa0000-0000-0000-0000-000000000000",
|
||||||
|
"devices": [{"ipv4": "10.0.0.2", "ipv6": None}],
|
||||||
|
"rules": [],
|
||||||
|
}])
|
||||||
|
|
||||||
|
batch_cmds = mock_batch.call_args[0][0]
|
||||||
|
batch_text = "\n".join(batch_cmds)
|
||||||
|
assert "delete chain" not in batch_text
|
||||||
|
|
@ -129,10 +129,17 @@ async def apply_rule(user_id: str, destination: str, action: str, port_type: str
|
||||||
async def rebuild_all_rules(users_devices_rules: list[dict]) -> None:
|
async def rebuild_all_rules(users_devices_rules: list[dict]) -> None:
|
||||||
"""Full reconciliation: flush and rebuild all per-user chains from DB state.
|
"""Full reconciliation: flush and rebuild all per-user chains from DB state.
|
||||||
|
|
||||||
|
Removes orphaned user chains that are no longer in the DB.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
users_devices_rules: list of dicts with keys:
|
users_devices_rules: list of dicts with keys:
|
||||||
user_id, devices (list of {ipv4, ipv6}), rules (list of {destination, action, port_type, port_range})
|
user_id, devices (list of {ipv4, ipv6}), rules (list of {destination, action, port_type, port_range})
|
||||||
"""
|
"""
|
||||||
|
# Discover existing user_ chains so we can remove orphans
|
||||||
|
existing_user_chains = await _list_user_chains()
|
||||||
|
expected_chains = {_user_chain_name(e["user_id"]) for e in users_devices_rules}
|
||||||
|
orphaned_chains = existing_user_chains - expected_chains
|
||||||
|
|
||||||
commands = []
|
commands = []
|
||||||
|
|
||||||
for entry in users_devices_rules:
|
for entry in users_devices_rules:
|
||||||
|
|
@ -162,9 +169,16 @@ async def rebuild_all_rules(users_devices_rules: list[dict]) -> None:
|
||||||
if dev.get("ipv6"):
|
if dev.get("ipv6"):
|
||||||
commands.append(f"add rule inet {TABLE_NAME} forward ip6 saddr {dev['ipv6']} jump {chain}")
|
commands.append(f"add rule inet {TABLE_NAME} forward ip6 saddr {dev['ipv6']} jump {chain}")
|
||||||
|
|
||||||
if commands:
|
# Remove orphaned user chains (must happen after forward chain is flushed
|
||||||
await _nft_batch(commands)
|
# so there are no remaining jump references to these chains)
|
||||||
logger.info("Firewall rules rebuilt for {} users", len(users_devices_rules))
|
for chain in orphaned_chains:
|
||||||
|
commands.append(f"flush chain inet {TABLE_NAME} {chain}")
|
||||||
|
commands.append(f"delete chain inet {TABLE_NAME} {chain}")
|
||||||
|
|
||||||
|
await _nft_batch(commands)
|
||||||
|
if orphaned_chains:
|
||||||
|
logger.info("Removed {} orphaned firewall chain(s): {}", len(orphaned_chains), orphaned_chains)
|
||||||
|
logger.info("Firewall rules rebuilt for {} users", len(users_devices_rules))
|
||||||
|
|
||||||
|
|
||||||
async def apply_peer_to_peer_policy(enabled: bool) -> None:
|
async def apply_peer_to_peer_policy(enabled: bool) -> None:
|
||||||
|
|
@ -235,6 +249,21 @@ async def get_ruleset() -> str:
|
||||||
return "nftables is not available.\n\nThis requires root/NET_ADMIN privileges (production container)."
|
return "nftables is not available.\n\nThis requires root/NET_ADMIN privileges (production container)."
|
||||||
|
|
||||||
|
|
||||||
|
async def _list_user_chains() -> set[str]:
|
||||||
|
"""Return the set of user_ chain names currently in the wiregui table."""
|
||||||
|
try:
|
||||||
|
output = await _nft(f"list table inet {TABLE_NAME}")
|
||||||
|
except RuntimeError:
|
||||||
|
return set()
|
||||||
|
chains = set()
|
||||||
|
for line in output.splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if line.startswith("chain user_"):
|
||||||
|
name = line.split()[1]
|
||||||
|
chains.add(name)
|
||||||
|
return chains
|
||||||
|
|
||||||
|
|
||||||
def _user_chain_name(user_id: str) -> str:
|
def _user_chain_name(user_id: str) -> str:
|
||||||
"""Generate a deterministic chain name from a user ID."""
|
"""Generate a deterministic chain name from a user ID."""
|
||||||
# Use first 12 chars of UUID (without hyphens) to keep names short
|
# Use first 12 chars of UUID (without hyphens) to keep names short
|
||||||
|
|
|
||||||
|
|
@ -83,8 +83,7 @@ async def _reconcile_firewall(devices: list[Device], rules: list[Rule]) -> None:
|
||||||
],
|
],
|
||||||
})
|
})
|
||||||
|
|
||||||
if entries:
|
try:
|
||||||
try:
|
await firewall.rebuild_all_rules(entries)
|
||||||
await firewall.rebuild_all_rules(entries)
|
except Exception as e:
|
||||||
except Exception as e:
|
logger.error("Reconcile: firewall rebuild failed: {}", e)
|
||||||
logger.error("Reconcile: firewall rebuild failed: {}", e)
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue