`rebuild_all_rules` now discovers existing `user_*` chains and removes any that are no longer in the DB. Reconcile now always runs the firewall rebuild, even with zero devices, so stale forward rules and orphaned chains are cleaned up after all devices are deleted.
266 lines
No EOL
9.6 KiB
Python
266 lines
No EOL
9.6 KiB
Python
"""Extended firewall tests — _nft/_nft_batch error handling, add_device_jump_rule edge cases, policies."""
|
|
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
|
|
from wiregui.services.firewall import (
|
|
_nft,
|
|
_nft_batch,
|
|
add_device_jump_rule,
|
|
rebuild_all_rules,
|
|
setup_base_tables,
|
|
setup_masquerade,
|
|
apply_peer_to_peer_policy,
|
|
apply_lan_to_peers_policy,
|
|
get_ruleset,
|
|
)
|
|
|
|
|
|
# ========== _nft error handling ==========
|
|
|
|
|
|
@patch("asyncio.create_subprocess_exec")
async def test_nft_raises_on_failure(mock_exec):
    """A non-zero exit status from nft must surface as a RuntimeError."""
    failing_proc = AsyncMock(returncode=1)
    failing_proc.communicate.return_value = (b"", b"nft: error message")
    mock_exec.return_value = failing_proc

    with pytest.raises(RuntimeError, match="nft.*failed"):
        await _nft("list ruleset")
|
|
|
|
|
|
@patch("asyncio.create_subprocess_exec")
async def test_nft_returns_stdout_on_success(mock_exec):
    """_nft returns the decoded stdout of a successful nft invocation.

    The exact text is asserted rather than a substring, so a result carrying
    extra or mangled content (e.g. stderr mixed in) is caught.
    """
    mock_proc = AsyncMock()
    mock_proc.communicate.return_value = (b"table inet wiregui {}", b"")
    mock_proc.returncode = 0
    mock_exec.return_value = mock_proc

    result = await _nft("list ruleset")

    mock_exec.assert_awaited_once()
    assert result == "table inet wiregui {}"
|
|
|
|
|
|
# ========== _nft_batch error handling ==========
|
|
|
|
|
|
@patch("asyncio.create_subprocess_exec")
async def test_nft_batch_raises_on_failure(mock_exec):
    """A failing `nft -f -` run must be reported as a RuntimeError."""
    mock_exec.return_value = AsyncMock(returncode=1)
    mock_exec.return_value.communicate.return_value = (b"", b"Error: syntax error")

    with pytest.raises(RuntimeError, match="nft batch failed"):
        await _nft_batch(["add table inet wiregui"])
|
|
|
|
|
|
@patch("asyncio.create_subprocess_exec")
async def test_nft_batch_sends_commands_via_stdin(mock_exec):
    """_nft_batch pipes every command, in order, to a single `nft -f -`."""
    mock_proc = AsyncMock()
    mock_proc.communicate.return_value = (b"", b"")
    mock_proc.returncode = 0
    mock_exec.return_value = mock_proc

    cmds = ["add table inet wiregui", "add chain inet wiregui test"]
    await _nft_batch(cmds)

    # Exactly one nft process, reading its script from stdin.
    mock_exec.assert_awaited_once()
    assert mock_exec.call_args.args == ("nft", "-f", "-")

    # The script is handed over in one communicate() call; accept it whether
    # it was passed positionally or as the ``input=`` keyword.
    mock_proc.communicate.assert_awaited_once()
    comm_call = mock_proc.communicate.call_args
    if "input" in comm_call.kwargs:
        stdin_data = comm_call.kwargs["input"]
    else:
        stdin_data = comm_call.args[0]
    assert b"add table inet wiregui" in stdin_data
    assert b"add chain inet wiregui test" in stdin_data
    # Order matters to nft: the chain can only be added once its table exists.
    assert stdin_data.index(b"add table inet wiregui") < stdin_data.index(
        b"add chain inet wiregui test"
    )
|
|
|
|
|
|
# ========== add_device_jump_rule edge cases ==========
|
|
|
|
|
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_add_device_jump_rule_ipv4_only(mock_batch):
    """A device with only an IPv4 address gets exactly one IPv4 jump rule."""
    await add_device_jump_rule("user-id-1", "10.0.0.5", None)

    mock_batch.assert_awaited_once()
    emitted = mock_batch.call_args.args[0]
    assert len(emitted) == 1
    only_rule = emitted[0]
    assert "ip saddr 10.0.0.5" in only_rule
    assert "jump" in only_rule
|
|
|
|
|
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_add_device_jump_rule_ipv6_only(mock_batch):
    """A device with only an IPv6 address gets exactly one IPv6 jump rule."""
    await add_device_jump_rule("user-id-2", None, "fd00::5")

    mock_batch.assert_awaited_once()
    emitted = mock_batch.call_args.args[0]
    assert len(emitted) == 1
    only_rule = emitted[0]
    assert "ip6 saddr fd00::5" in only_rule
    assert "jump" in only_rule
|
|
|
|
|
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_add_device_jump_rule_no_ips(mock_batch):
    """A device with no addresses at all must not trigger any nft batch."""
    await add_device_jump_rule("user-id-3", None, None)
    assert mock_batch.await_count == 0
|
|
|
|
|
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_add_device_jump_rule_both_ips(mock_batch):
    """Both IPv4 and IPv6 — generates two jump rules, one per address family."""
    await add_device_jump_rule("user-id-4", "10.0.0.7", "fd00::7")

    mock_batch.assert_awaited_once()
    cmds = mock_batch.call_args.args[0]
    assert len(cmds) == 2
    assert any("ip saddr 10.0.0.7" in c for c in cmds)
    assert any("ip6 saddr fd00::7" in c for c in cmds)
    # Like the single-family cases, every emitted rule must be a jump rule.
    assert all("jump" in c for c in cmds)
|
|
|
|
|
|
# ========== setup_base_tables — already exists ==========
|
|
|
|
|
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_setup_base_tables_already_exists(mock_batch):
    """If table already exists (File exists error), don't raise."""
    mock_batch.side_effect = RuntimeError("File exists")

    await setup_base_tables()  # should not raise

    # Guard against a vacuous pass: the "File exists" error must actually have
    # been raised by _nft_batch and then tolerated by setup_base_tables.
    mock_batch.assert_awaited()
|
|
|
|
|
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_setup_base_tables_other_error_raises(mock_batch):
    """Any nft failure other than "File exists" must propagate to the caller."""
    denied = RuntimeError("Permission denied")
    mock_batch.side_effect = denied

    with pytest.raises(RuntimeError, match="Permission denied"):
        await setup_base_tables()
|
|
|
|
|
|
# ========== setup_masquerade — error handling ==========
|
|
|
|
|
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_setup_masquerade_error_swallowed(mock_batch):
    """Masquerade errors are logged but not raised."""
    mock_batch.side_effect = RuntimeError("nft error")

    await setup_masquerade(iface="wg0")  # should not raise

    # Guard against a vacuous pass: the error path must actually have been hit.
    mock_batch.assert_awaited()
|
|
|
|
|
|
# ========== policy functions — command verification ==========
|
|
|
|
|
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_peer_to_peer_enabled(mock_batch):
    """Enabling peer-to-peer generates accept rules for the peer_to_peer chain."""
    await apply_peer_to_peer_policy(True)

    # Fail with a clear message if no batch was sent, rather than a TypeError
    # from subscripting a None call_args.
    mock_batch.assert_awaited()
    cmds = mock_batch.call_args.args[0]
    assert any("accept" in c for c in cmds)
    assert any("peer_to_peer" in c for c in cmds)
|
|
|
|
|
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_peer_to_peer_disabled(mock_batch):
    """Disabling peer-to-peer generates drop rules for the peer_to_peer chain."""
    await apply_peer_to_peer_policy(False)

    # Fail with a clear message if no batch was sent, rather than a TypeError
    # from subscripting a None call_args.
    mock_batch.assert_awaited()
    cmds = mock_batch.call_args.args[0]
    assert any("drop" in c for c in cmds)
    # Mirror the enabled case: the policy must target the peer_to_peer chain.
    assert any("peer_to_peer" in c for c in cmds)
|
|
|
|
|
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_lan_to_peers_enabled(mock_batch):
    """Enabling LAN-to-peers generates accept rules for the lan_to_peers chain."""
    await apply_lan_to_peers_policy(True)

    # Fail with a clear message if no batch was sent, rather than a TypeError
    # from subscripting a None call_args.
    mock_batch.assert_awaited()
    cmds = mock_batch.call_args.args[0]
    assert any("accept" in c for c in cmds)
    assert any("lan_to_peers" in c for c in cmds)
|
|
|
|
|
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
async def test_lan_to_peers_disabled(mock_batch):
    """Disabling LAN-to-peers generates drop rules for the lan_to_peers chain."""
    await apply_lan_to_peers_policy(False)

    # Fail with a clear message if no batch was sent, rather than a TypeError
    # from subscripting a None call_args.
    mock_batch.assert_awaited()
    cmds = mock_batch.call_args.args[0]
    assert any("drop" in c for c in cmds)
    # Mirror the enabled case: the policy must target the lan_to_peers chain.
    assert any("lan_to_peers" in c for c in cmds)
|
|
|
|
|
|
# ========== get_ruleset — error handling ==========
|
|
|
|
|
|
@patch("wiregui.services.firewall._nft", new_callable=AsyncMock)
async def test_get_ruleset_returns_output(mock_nft):
    """get_ruleset surfaces the text produced by `nft list ruleset`."""
    canned_listing = "table inet wiregui { ... }"
    mock_nft.return_value = canned_listing
    assert "wiregui" in await get_ruleset()
|
|
|
|
|
|
@patch("wiregui.services.firewall._nft", new_callable=AsyncMock)
async def test_get_ruleset_returns_fallback_on_error(mock_nft):
    """When nft cannot be run, get_ruleset degrades to a readable notice."""
    missing_binary = RuntimeError("nft not found")
    mock_nft.side_effect = missing_binary
    fallback_text = await get_ruleset()
    assert "not available" in fallback_text
|
|
|
|
|
|
# ========== rebuild_all_rules — orphan cleanup ==========
|
|
|
|
|
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
@patch("wiregui.services.firewall._list_user_chains", new_callable=AsyncMock)
async def test_rebuild_removes_orphaned_user_chains(mock_list, mock_batch):
    """Orphaned user chains (in nft but not in DB) should be flushed and deleted.

    nft refuses to delete a chain that still contains rules, so an orphan's
    flush must precede its deletion. Commands from every _nft_batch call are
    inspected (in call order), not just the last call, so the assertions do
    not depend on how the rebuild splits its work into batches.
    """
    mock_list.return_value = {"user_aaaa00000000", "user_bbbb00000000"}

    # Only user_aaaa is still in the DB
    await rebuild_all_rules([{
        "user_id": "aaaa0000-0000-0000-0000-000000000000",
        "devices": [{"ipv4": "10.0.0.2", "ipv6": None}],
        "rules": [],
    }])

    mock_batch.assert_awaited()
    batch_text = "\n".join(
        cmd for batch_call in mock_batch.call_args_list for cmd in batch_call.args[0]
    )

    flush_bbbb = "flush chain inet wiregui user_bbbb00000000"
    delete_bbbb = "delete chain inet wiregui user_bbbb00000000"
    # user_bbbb should be flushed and deleted, in that order
    assert flush_bbbb in batch_text
    assert delete_bbbb in batch_text
    assert batch_text.index(flush_bbbb) < batch_text.index(delete_bbbb)
    # user_aaaa should NOT be deleted
    assert "delete chain inet wiregui user_aaaa00000000" not in batch_text
|
|
|
|
|
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
@patch("wiregui.services.firewall._list_user_chains", new_callable=AsyncMock)
async def test_rebuild_with_no_devices_clears_forward_and_orphans(mock_list, mock_batch):
    """With zero devices, forward chain should be flushed and all user chains removed.

    nft only deletes a chain that holds no rules and is not a jump target.
    The forward-chain flush (which drops stale forward rules) and each user
    chain's own flush must therefore come before that user chain is deleted.
    Commands from every _nft_batch call are inspected, in call order.
    """
    mock_list.return_value = {"user_aaaa00000000", "user_bbbb00000000"}

    await rebuild_all_rules([])

    mock_batch.assert_awaited()
    batch_text = "\n".join(
        cmd for batch_call in mock_batch.call_args_list for cmd in batch_call.args[0]
    )

    flush_forward = "flush chain inet wiregui forward"
    # Forward chain must be flushed even with no devices
    assert flush_forward in batch_text
    # Both orphans removed, each only after the forward flush and its own flush
    for chain in ("user_aaaa00000000", "user_bbbb00000000"):
        flush_cmd = f"flush chain inet wiregui {chain}"
        delete_cmd = f"delete chain inet wiregui {chain}"
        assert flush_cmd in batch_text
        assert delete_cmd in batch_text
        assert batch_text.index(flush_forward) < batch_text.index(delete_cmd)
        assert batch_text.index(flush_cmd) < batch_text.index(delete_cmd)
|
|
|
|
|
|
@patch("wiregui.services.firewall._nft_batch", new_callable=AsyncMock)
@patch("wiregui.services.firewall._list_user_chains", new_callable=AsyncMock)
async def test_rebuild_no_orphans_no_deletions(mock_list, mock_batch):
    """When all nft chains match the DB, no deletions should occur.

    Chain discovery and the batch are both asserted to have run, so the
    negative check cannot pass vacuously. Every _nft_batch call is inspected,
    not just the last one, so a deletion issued in an earlier batch is caught.
    """
    mock_list.return_value = {"user_aaaa00000000"}

    await rebuild_all_rules([{
        "user_id": "aaaa0000-0000-0000-0000-000000000000",
        "devices": [{"ipv4": "10.0.0.2", "ipv6": None}],
        "rules": [],
    }])

    mock_list.assert_awaited()
    mock_batch.assert_awaited()
    batch_text = "\n".join(
        cmd for batch_call in mock_batch.call_args_list for cmd in batch_call.args[0]
    )
    assert "delete chain" not in batch_text