Spaces:
Running
Running
from typing import Optional, Union | |
from urllib.parse import urlparse | |
import dns.asyncbackend | |
import dns.asyncquery | |
import dns.inet | |
import dns.message | |
import dns.query | |
class Nameserver: | |
def __init__(self): | |
pass | |
def __str__(self): | |
raise NotImplementedError | |
def kind(self) -> str: | |
raise NotImplementedError | |
def is_always_max_size(self) -> bool: | |
raise NotImplementedError | |
def answer_nameserver(self) -> str: | |
raise NotImplementedError | |
def answer_port(self) -> int: | |
raise NotImplementedError | |
def query( | |
self, | |
request: dns.message.QueryMessage, | |
timeout: float, | |
source: Optional[str], | |
source_port: int, | |
max_size: bool, | |
one_rr_per_rrset: bool = False, | |
ignore_trailing: bool = False, | |
) -> dns.message.Message: | |
raise NotImplementedError | |
async def async_query( | |
self, | |
request: dns.message.QueryMessage, | |
timeout: float, | |
source: Optional[str], | |
source_port: int, | |
max_size: bool, | |
backend: dns.asyncbackend.Backend, | |
one_rr_per_rrset: bool = False, | |
ignore_trailing: bool = False, | |
) -> dns.message.Message: | |
raise NotImplementedError | |
class AddressAndPortNameserver(Nameserver): | |
def __init__(self, address: str, port: int): | |
super().__init__() | |
self.address = address | |
self.port = port | |
def kind(self) -> str: | |
raise NotImplementedError | |
def is_always_max_size(self) -> bool: | |
return False | |
def __str__(self): | |
ns_kind = self.kind() | |
return f"{ns_kind}:{self.address}@{self.port}" | |
def answer_nameserver(self) -> str: | |
return self.address | |
def answer_port(self) -> int: | |
return self.port | |
class Do53Nameserver(AddressAndPortNameserver): | |
def __init__(self, address: str, port: int = 53): | |
super().__init__(address, port) | |
def kind(self): | |
return "Do53" | |
def query( | |
self, | |
request: dns.message.QueryMessage, | |
timeout: float, | |
source: Optional[str], | |
source_port: int, | |
max_size: bool, | |
one_rr_per_rrset: bool = False, | |
ignore_trailing: bool = False, | |
) -> dns.message.Message: | |
if max_size: | |
response = dns.query.tcp( | |
request, | |
self.address, | |
timeout=timeout, | |
port=self.port, | |
source=source, | |
source_port=source_port, | |
one_rr_per_rrset=one_rr_per_rrset, | |
ignore_trailing=ignore_trailing, | |
) | |
else: | |
response = dns.query.udp( | |
request, | |
self.address, | |
timeout=timeout, | |
port=self.port, | |
source=source, | |
source_port=source_port, | |
raise_on_truncation=True, | |
one_rr_per_rrset=one_rr_per_rrset, | |
ignore_trailing=ignore_trailing, | |
ignore_errors=True, | |
ignore_unexpected=True, | |
) | |
return response | |
async def async_query( | |
self, | |
request: dns.message.QueryMessage, | |
timeout: float, | |
source: Optional[str], | |
source_port: int, | |
max_size: bool, | |
backend: dns.asyncbackend.Backend, | |
one_rr_per_rrset: bool = False, | |
ignore_trailing: bool = False, | |
) -> dns.message.Message: | |
if max_size: | |
response = await dns.asyncquery.tcp( | |
request, | |
self.address, | |
timeout=timeout, | |
port=self.port, | |
source=source, | |
source_port=source_port, | |
backend=backend, | |
one_rr_per_rrset=one_rr_per_rrset, | |
ignore_trailing=ignore_trailing, | |
) | |
else: | |
response = await dns.asyncquery.udp( | |
request, | |
self.address, | |
timeout=timeout, | |
port=self.port, | |
source=source, | |
source_port=source_port, | |
raise_on_truncation=True, | |
backend=backend, | |
one_rr_per_rrset=one_rr_per_rrset, | |
ignore_trailing=ignore_trailing, | |
ignore_errors=True, | |
ignore_unexpected=True, | |
) | |
return response | |
class DoHNameserver(Nameserver): | |
def __init__( | |
self, | |
url: str, | |
bootstrap_address: Optional[str] = None, | |
verify: Union[bool, str] = True, | |
want_get: bool = False, | |
): | |
super().__init__() | |
self.url = url | |
self.bootstrap_address = bootstrap_address | |
self.verify = verify | |
self.want_get = want_get | |
def kind(self): | |
return "DoH" | |
def is_always_max_size(self) -> bool: | |
return True | |
def __str__(self): | |
return self.url | |
def answer_nameserver(self) -> str: | |
return self.url | |
def answer_port(self) -> int: | |
port = urlparse(self.url).port | |
if port is None: | |
port = 443 | |
return port | |
def query( | |
self, | |
request: dns.message.QueryMessage, | |
timeout: float, | |
source: Optional[str], | |
source_port: int, | |
max_size: bool = False, | |
one_rr_per_rrset: bool = False, | |
ignore_trailing: bool = False, | |
) -> dns.message.Message: | |
return dns.query.https( | |
request, | |
self.url, | |
timeout=timeout, | |
source=source, | |
source_port=source_port, | |
bootstrap_address=self.bootstrap_address, | |
one_rr_per_rrset=one_rr_per_rrset, | |
ignore_trailing=ignore_trailing, | |
verify=self.verify, | |
post=(not self.want_get), | |
) | |
async def async_query( | |
self, | |
request: dns.message.QueryMessage, | |
timeout: float, | |
source: Optional[str], | |
source_port: int, | |
max_size: bool, | |
backend: dns.asyncbackend.Backend, | |
one_rr_per_rrset: bool = False, | |
ignore_trailing: bool = False, | |
) -> dns.message.Message: | |
return await dns.asyncquery.https( | |
request, | |
self.url, | |
timeout=timeout, | |
source=source, | |
source_port=source_port, | |
bootstrap_address=self.bootstrap_address, | |
one_rr_per_rrset=one_rr_per_rrset, | |
ignore_trailing=ignore_trailing, | |
verify=self.verify, | |
post=(not self.want_get), | |
) | |
class DoTNameserver(AddressAndPortNameserver): | |
def __init__( | |
self, | |
address: str, | |
port: int = 853, | |
hostname: Optional[str] = None, | |
verify: Union[bool, str] = True, | |
): | |
super().__init__(address, port) | |
self.hostname = hostname | |
self.verify = verify | |
def kind(self): | |
return "DoT" | |
def query( | |
self, | |
request: dns.message.QueryMessage, | |
timeout: float, | |
source: Optional[str], | |
source_port: int, | |
max_size: bool = False, | |
one_rr_per_rrset: bool = False, | |
ignore_trailing: bool = False, | |
) -> dns.message.Message: | |
return dns.query.tls( | |
request, | |
self.address, | |
port=self.port, | |
timeout=timeout, | |
one_rr_per_rrset=one_rr_per_rrset, | |
ignore_trailing=ignore_trailing, | |
server_hostname=self.hostname, | |
verify=self.verify, | |
) | |
async def async_query( | |
self, | |
request: dns.message.QueryMessage, | |
timeout: float, | |
source: Optional[str], | |
source_port: int, | |
max_size: bool, | |
backend: dns.asyncbackend.Backend, | |
one_rr_per_rrset: bool = False, | |
ignore_trailing: bool = False, | |
) -> dns.message.Message: | |
return await dns.asyncquery.tls( | |
request, | |
self.address, | |
port=self.port, | |
timeout=timeout, | |
one_rr_per_rrset=one_rr_per_rrset, | |
ignore_trailing=ignore_trailing, | |
server_hostname=self.hostname, | |
verify=self.verify, | |
) | |
class DoQNameserver(AddressAndPortNameserver): | |
def __init__( | |
self, | |
address: str, | |
port: int = 853, | |
verify: Union[bool, str] = True, | |
server_hostname: Optional[str] = None, | |
): | |
super().__init__(address, port) | |
self.verify = verify | |
self.server_hostname = server_hostname | |
def kind(self): | |
return "DoQ" | |
def query( | |
self, | |
request: dns.message.QueryMessage, | |
timeout: float, | |
source: Optional[str], | |
source_port: int, | |
max_size: bool = False, | |
one_rr_per_rrset: bool = False, | |
ignore_trailing: bool = False, | |
) -> dns.message.Message: | |
return dns.query.quic( | |
request, | |
self.address, | |
port=self.port, | |
timeout=timeout, | |
one_rr_per_rrset=one_rr_per_rrset, | |
ignore_trailing=ignore_trailing, | |
verify=self.verify, | |
server_hostname=self.server_hostname, | |
) | |
async def async_query( | |
self, | |
request: dns.message.QueryMessage, | |
timeout: float, | |
source: Optional[str], | |
source_port: int, | |
max_size: bool, | |
backend: dns.asyncbackend.Backend, | |
one_rr_per_rrset: bool = False, | |
ignore_trailing: bool = False, | |
) -> dns.message.Message: | |
return await dns.asyncquery.quic( | |
request, | |
self.address, | |
port=self.port, | |
timeout=timeout, | |
one_rr_per_rrset=one_rr_per_rrset, | |
ignore_trailing=ignore_trailing, | |
verify=self.verify, | |
server_hostname=self.server_hostname, | |
) | |