diff --git a/wiregui/utils/network.py b/wiregui/utils/network.py index 56f5266..3bb505b 100644 --- a/wiregui/utils/network.py +++ b/wiregui/utils/network.py @@ -1,7 +1,7 @@ """IP address allocation for WireGuard tunnel addresses.""" import random -from ipaddress import IPv4Network, IPv6Network, ip_address +from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network from loguru import logger from sqlalchemy.ext.asyncio import AsyncSession @@ -11,17 +11,17 @@ from wiregui.models.device import Device async def allocate_ipv4(session: AsyncSession, network_cidr: str) -> str: - """Find the next available IPv4 address in the given CIDR range.""" + """Find an available IPv4 address in the given CIDR range.""" network = IPv4Network(network_cidr, strict=False) used = await _get_used_ips(session, "ipv4") - return _find_available(network, used) + return _find_available_v4(network, used) async def allocate_ipv6(session: AsyncSession, network_cidr: str) -> str: - """Find the next available IPv6 address in the given CIDR range.""" + """Find an available IPv6 address in the given CIDR range.""" network = IPv6Network(network_cidr, strict=False) used = await _get_used_ips(session, "ipv6") - return _find_available(network, used) + return _find_available_v6(network, used) async def _get_used_ips(session: AsyncSession, field: str) -> set[str]: @@ -31,30 +31,54 @@ async def _get_used_ips(session: AsyncSession, field: str) -> set[str]: return {row[0] for row in result.all()} -def _find_available(network: IPv4Network | IPv6Network, used: set[str]) -> str: - """Find an available IP in the network, starting from a random offset.""" - hosts = list(network.hosts()) - if not hosts: +def _find_available_v4(network: IPv4Network, used: set[str]) -> str: + """Find an available IPv4 by random sampling — O(1) per attempt, no list materialization.""" + # Usable range: network_address + 2 to broadcast - 1 (skip network, gateway, broadcast) + first = int(network.network_address) + 2 + last = int(network.broadcast_address) - 1 + pool_size = last - first + 1 + + if pool_size <= 0: raise ValueError(f"No usable hosts in {network}") + if len(used) >= pool_size: + raise ValueError(f"No available addresses in {network}") - # Skip the first host (gateway/server address) - hosts = hosts[1:] - if not hosts: - raise ValueError(f"No usable hosts in {network} after reserving gateway") - - # Start from a random offset, then scan forward and backward - start = random.randint(0, len(hosts) - 1) - - # Forward scan - for i in range(start, len(hosts)): - candidate = str(hosts[i]) + for _ in range(min(pool_size, 1000)): + candidate = str(IPv4Address(random.randint(first, last))) if candidate not in used: logger.debug("Allocated {} from {}", candidate, network) return candidate - # Backward scan - for i in range(start - 1, -1, -1): - candidate = str(hosts[i]) + # Fallback: sequential scan (only if random sampling keeps hitting used IPs) + for offset in range(pool_size): + candidate = str(IPv4Address(first + offset)) + if candidate not in used: + logger.debug("Allocated {} from {}", candidate, network) + return candidate + + raise ValueError(f"No available addresses in {network}") + + +def _find_available_v6(network: IPv6Network, used: set[str]) -> str: + """Find an available IPv6 by random sampling.""" + first = int(network.network_address) + 2 + last = int(network.broadcast_address) - 1 + pool_size = last - first + 1 + + if pool_size <= 0: + raise ValueError(f"No usable hosts in {network}") + if len(used) >= pool_size: + raise ValueError(f"No available addresses in {network}") + + for _ in range(min(pool_size, 1000)): + candidate = str(IPv6Address(random.randint(first, last))) + if candidate not in used: + logger.debug("Allocated {} from {}", candidate, network) + return candidate + + # Fallback: sequential scan + for offset in range(pool_size): + candidate = str(IPv6Address(first + offset)) if candidate not in used: logger.debug("Allocated {} from {}", candidate, network) return candidate