diff options
Diffstat (limited to 'pkb_client/client')
| -rw-r--r-- | pkb_client/client/bind_file.py | 45 | ||||
| -rw-r--r-- | pkb_client/client/client.py | 64 | ||||
| -rw-r--r-- | pkb_client/client/dns.py | 43 |
3 files changed, 120 insertions, 32 deletions
diff --git a/pkb_client/client/bind_file.py b/pkb_client/client/bind_file.py index af9abe0..17bb5ee 100644 --- a/pkb_client/client/bind_file.py +++ b/pkb_client/client/bind_file.py @@ -27,7 +27,7 @@ class BindRecord: record_string = f"{self.name} {self.ttl} {self.record_class} {self.record_type}" if self.prio is not None: record_string += f" {self.prio}" - record_string += f" {self.data}" + record_string += f' "{self.data}"' if self.comment: record_string += f" ; {self.comment}" return record_string @@ -68,18 +68,14 @@ class BindFile: # 2: name record-class ttl record-type record-data # whereby the ttl is optional - # drop any right trailing comments - line_parts = line.split(";", 1) - line = line_parts[0].strip() - comment = line_parts[1].strip() if len(line_parts) > 1 else None - prio = None + record_parts = line.strip().split() - # skip empty lines - if not line: + # skip comments + if not record_parts or record_parts[0].startswith(";"): continue - # find which format the line is - record_parts = line.split() + prio = None + if record_parts[1].isdigit(): # scheme 1 if record_parts[3] not in DNSRecordType.__members__: @@ -137,6 +133,35 @@ class BindFile: # replace @ in record name with origin record_name = record_name.replace("@", origin) + # handle comments and quoted strings as record data + comment = None + line = record_data.strip() + if line.startswith('"'): + # find rightmost double quote + rindex = line.rfind('"') + if rindex != -1: + # split at the last double quote + line_parts = line.rsplit('"', 1) + record_data = line_parts[0].strip('"') + + comment = line_parts[1].strip() if len(line_parts) > 1 else None + # left strip semicolon from comment + if comment and comment.startswith(";"): + comment = comment[1:].strip() + + if not comment: + comment = None + else: + record_data = line.strip('"') + else: + # try to split at the first semicolon for comments + if ";" in line: + record_data, comment = line.split(";", 1) + record_data = record_data.strip() + comment = comment.strip() + else: + record_data = line + records.append( BindRecord( record_name, diff --git a/pkb_client/client/client.py b/pkb_client/client/client.py index 81b946b..469751e 100644 --- a/pkb_client/client/client.py +++ b/pkb_client/client/client.py @@ -128,7 +128,7 @@ class PKBClient: req_json = { **self._get_auth_request_json(), "name": name, - "type": record_type, + "type": record_type.value, "content": content, "ttl": ttl, "prio": prio, @@ -182,7 +182,7 @@ class PKBClient: req_json = { **self._get_auth_request_json(), "name": name, - "type": record_type, + "type": record_type.value, "content": content, "ttl": ttl, "prio": prio, @@ -234,7 +234,7 @@ class PKBClient: ) req_json = { **self._get_auth_request_json(), - "type": record_type, + "type": record_type.value, "content": content, "ttl": ttl, "prio": prio, @@ -430,7 +430,7 @@ class PKBClient: logger.warning("file already exists, overwriting...") # domain header - bind_file_content = f"$ORIGIN {domain}" + bind_file_content = f"$ORIGIN {domain}." # SOA record soa_records = dns.resolver.resolve(domain, "SOA") @@ -441,7 +441,15 @@ class PKBClient: # records for record in dns_records: # name record class ttl record type record data - if record.prio: + # add trailing dot to the name if it is a supported record type, to make it a fully qualified domain name + if record.type in [ + DNSRecordType.MX, + DNSRecordType.CNAME, + DNSRecordType.NS, + DNSRecordType.SRV, + ]: + record.content += "." + if record.prio is not None: record_content = f"{record.prio} {record.content}" else: record_content = record.content @@ -498,7 +506,7 @@ class PKBClient: name = ".".join(exported_record["name"].split(".")[:-2]) self.create_dns_record( domain=domain, - record_type=exported_record["type"], + record_type=DNSRecordType(exported_record["type"]), content=exported_record["content"], name=name, ttl=exported_record["ttl"], @@ -534,7 +542,7 @@ class PKBClient: self.update_dns_record( domain=domain, record_id=existing_record.id, - record_type=record["type"], + record_type=DNSRecordType(record["type"]), content=record["content"], name=record["name"].replace(f".{domain}", ""), ttl=record["ttl"], @@ -564,7 +572,7 @@ class PKBClient: if existing_record is None: self.create_dns_record( domain=domain, - record_type=record["type"], + record_type=DNSRecordType(record["type"]), content=record["content"], name=record["name"].replace(f".{domain}", ""), ttl=record["ttl"], @@ -607,12 +615,20 @@ class PKBClient: for record in existing_dns_records: self.delete_dns_record(bind_file.origin[:-1], record.id) + nameserver_records = [] # restore all records from BIND file by creating new DNS records for record in bind_file.records: - # extract subdomain from record name - subdomain = record.name.replace(bind_file.origin, "") - # replace trailing dot - subdomain = subdomain[:-1] if subdomain.endswith(".") else subdomain + if record.record_type == DNSRecordType.NS: + # collect nameserver records to update them later in bulk + nameserver_records.append(record) + continue + if record.name.endswith("."): + # extract subdomain from record name, by removing the domain and TLD + subdomain = record.name.removesuffix(bind_file.origin) + subdomain = subdomain.removesuffix(".") + else: + subdomain = record.name + self.create_dns_record( domain=bind_file.origin[:-1], record_type=record.record_type, @@ -622,6 +638,17 @@ class PKBClient: prio=record.prio, ) + # update nameservers in bulk + if nameserver_records: + name_servers = [] + # remove trailing dot from nameserver records + for nameserver in nameserver_records: + if nameserver.data.endswith("."): + name_servers.append(nameserver.data[:-1]) + else: + name_servers.append(nameserver.data) + self.update_dns_servers(bind_file.origin[:-1], name_servers) + except Exception as e: logger.error("something went wrong: {}".format(e.__str__())) self.__handle_error_backup__(existing_dns_records) @@ -755,7 +782,7 @@ class PKBClient: **self._get_auth_request_json(), "subdomain": subdomain, "location": location, - "type": type, + "type": type.value, "includePath": include_path, "wildcard": wildcard, } @@ -950,11 +977,18 @@ class PKBClient: ) @staticmethod - def __handle_error_backup__(dns_records): + def __handle_error_backup__(dns_records: list[DNSRecord]) -> None: + """ + Handle errors when working with dns records by creating a backup of the given DNS records. + Crates a backup file in the current working directory with an incremental suffix. + + :param dns_records: the DNS records to backup + """ + # merge the single DNS records into one single dict with the record id as key dns_records_dict = dict() for record in dns_records: - dns_records_dict[record["id"]] = record + dns_records_dict[record.id] = record.to_dict() # generate filename with incremental suffix base_backup_filename = "pkb_client_dns_records_backup" diff --git a/pkb_client/client/dns.py b/pkb_client/client/dns.py index 85ae971..8bedad2 100644 --- a/pkb_client/client/dns.py +++ b/pkb_client/client/dns.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum -from typing import Optional +from typing import Optional, Any class DNSRecordType(str, Enum): @@ -33,7 +33,14 @@ class DNSRecord: notes: str @staticmethod - def from_dict(d): + def from_dict(d: dict[str, Any]) -> "DNSRecord": + """ + Create a DNSRecord instance from a dictionary representation. + + :param d: Dictionary containing DNS record data. + :return: DNSRecord instance. + """ + # only use prio for supported record types since the API returns it for all records with default value 0 prio = int(d["prio"]) if d["type"] in DNS_RECORDS_WITH_PRIORITY else None return DNSRecord( @@ -46,6 +53,23 @@ class DNSRecord: notes=d["notes"], ) + def to_dict(self) -> dict[str, Any]: + """ + Convert the DNSRecord instance to a dictionary representation. + + :return: Dictionary containing DNS record data. + """ + + return { + "id": self.id, + "name": self.name, + "type": str(self.type), + "content": self.content, + "ttl": self.ttl, + "prio": self.prio, + "notes": self.notes, + } + class DNSRestoreMode(Enum): clear = 0 @@ -56,8 +80,13 @@ class DNSRestoreMode(Enum): return self.name @staticmethod - def from_string(a): - try: - return DNSRestoreMode[a] - except KeyError: - return a + def from_string(a: str) -> "DNSRestoreMode": + """ + Convert a string to a DNSRestoreMode enum member. + + :param a: String representation of the restore mode. + :return: Corresponding DNSRestoreMode enum member. + :raises KeyError: If the string does not match any enum member. + """ + + return DNSRestoreMode[a] |
