summaryrefslogtreecommitdiffstats
path: root/pkb_client/client/bind_file.py
blob: af9abe0a52562cb35c941b8809900ea64fdab40d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import logging
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Tuple

from pkb_client.client.dns import DNSRecordType, DNS_RECORDS_WITH_PRIORITY


class RecordClass(str, Enum):
    """DNS record classes accepted by the BIND zone-file parser.

    Only the Internet class (``IN``) is defined.
    """

    IN = "IN"

    def __str__(self):
        # Emit the bare mnemonic (e.g. ``IN``) instead of ``RecordClass.IN``
        # so members can be interpolated directly into zone-file lines.
        return str(self.value)


@dataclass
class BindRecord:
    """A single resource record as it appears in a BIND zone file."""

    # owner name of the record
    name: str
    # time-to-live in seconds
    ttl: int
    # record class column (e.g. ``IN``)
    record_class: RecordClass
    # record type column (e.g. ``A``, ``MX``)
    record_type: DNSRecordType
    # remaining record data, space-joined
    data: str
    # priority field; only rendered when not ``None``
    prio: Optional[int] = None
    # trailing comment text (without the leading ``;``)
    comment: Optional[str] = None

    def __str__(self):
        """Render the record as ``name ttl class type [prio] data [; comment]``."""
        tokens = [self.name, self.ttl, self.record_class, self.record_type]
        if self.prio is not None:
            tokens.append(self.prio)
        tokens.append(self.data)
        # format() matches the f-string interpolation used for every field
        rendered = " ".join(format(token) for token in tokens)
        return f"{rendered} ; {self.comment}" if self.comment else rendered


class BindFile:
    """In-memory model of a BIND zone file.

    Holds the zone origin, the optional default TTL (``$TTL``) and the parsed
    resource records. Use :meth:`from_file` to parse a file and
    :meth:`to_file` / ``str()`` to serialize one.
    """

    origin: str
    ttl: Optional[int] = None
    records: List[BindRecord]

    def __init__(
        self,
        origin: str,
        ttl: Optional[int] = None,
        records: Optional[List[BindRecord]] = None,
    ) -> None:
        """Create a zone file model.

        :param origin: zone origin, written as the ``$ORIGIN`` directive
        :param ttl: default TTL, written as ``$TTL``; omitted when ``None``
        :param records: the zone's records; a new empty list when not given
        """
        self.origin = origin
        self.ttl = ttl
        # fresh list per instance so no mutable default is ever shared
        self.records = records or []

    @staticmethod
    def _strip_comment(line: str) -> Tuple[str, Optional[str]]:
        """Split a zone-file line into its content and trailing comment.

        A ``;`` starts a comment only when it is outside a double-quoted
        string and not backslash-escaped, so TXT data such as
        ``"v=DKIM1; k=rsa; p=..."`` is kept intact.

        :param line: one raw line of the zone file
        :return: the stripped content, and the stripped comment text or
            ``None`` when the line has no comment
        """
        in_quotes = False
        escaped = False
        for index, char in enumerate(line):
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_quotes = not in_quotes
            elif char == ";" and not in_quotes:
                return line[:index].strip(), line[index + 1 :].strip()
        return line.strip(), None

    @staticmethod
    def from_file(file_path: str) -> "BindFile":
        """Parse a BIND zone file.

        Understands the ``$ORIGIN`` and ``$TTL`` directives and single-line
        records in any of these layouts::

            name  ttl    class  type  [prio]  data
            name  class  ttl    type  [prio]  data
            name  class  type   [prio]  data        (TTL omitted)

        A priority column is read only for types in
        ``DNS_RECORDS_WITH_PRIORITY``. A record without a TTL uses ``$TTL``
        if set, otherwise the previous record's TTL. ``@`` in a record name
        is replaced with the current origin. Records with an unsupported
        class or type, and lines with too few fields to be a record (for
        example continuation lines of a parenthesized multi-line record,
        which are not supported), are skipped with a logged warning. Lines
        are stripped, so records relying on an omitted owner name (leading
        whitespace) are not supported.

        :param file_path: path of the zone file to read
        :return: the parsed zone
        :raises ValueError: if no ``$ORIGIN`` directive is present, if a
            record uses ``@`` before any ``$ORIGIN``, or if a record omits
            its TTL while neither ``$TTL`` nor an earlier record provides one
        """
        with open(file_path, "r") as f:
            file_data = f.readlines()

        origin: Optional[str] = None
        ttl: Optional[int] = None
        records: List[BindRecord] = []
        for raw_line in file_data:
            line, comment = BindFile._strip_comment(raw_line)
            # skip empty and comment-only lines
            if not line:
                continue

            if line.startswith("$ORIGIN"):
                origin = line.split()[1]
                continue
            if line.startswith("$TTL"):
                ttl = int(line.split()[1])
                continue

            record_parts = line.split()
            # locate the ttl / class / type columns for the layout in use
            if len(record_parts) > 1 and record_parts[1].isdigit():
                ttl_index, class_index, type_index = 1, 2, 3
            elif len(record_parts) > 2 and record_parts[2].isdigit():
                ttl_index, class_index, type_index = 2, 1, 3
            else:
                ttl_index, class_index, type_index = None, 1, 2

            if len(record_parts) <= type_index:
                logging.warning(f"Ignoring malformed record: {line}")
                continue
            if record_parts[type_index] not in DNSRecordType.__members__:
                logging.warning(f"Ignoring unsupported record type: {line}")
                continue
            if record_parts[class_index] not in RecordClass.__members__:
                logging.warning(f"Ignoring unsupported record class: {line}")
                continue

            record_name = record_parts[0]
            if ttl_index is not None:
                record_ttl = int(record_parts[ttl_index])
            elif ttl is not None:
                # explicit None check: a $TTL of 0 is a valid default
                record_ttl = ttl
            elif records:
                # no $TTL yet: inherit the TTL of the previous record
                record_ttl = records[-1].ttl
            else:
                raise ValueError("No TTL found in file")

            record_class = RecordClass[record_parts[class_index]]
            record_type = DNSRecordType[record_parts[type_index]]

            data_index = type_index + 1
            prio: Optional[int] = None
            if record_type in DNS_RECORDS_WITH_PRIORITY:
                if len(record_parts) <= data_index:
                    logging.warning(f"Ignoring malformed record: {line}")
                    continue
                prio = int(record_parts[data_index])
                data_index += 1
            record_data = " ".join(record_parts[data_index:])

            # replace @ in record name with origin
            if "@" in record_name:
                if origin is None:
                    raise ValueError(
                        f"No origin found before record using '@': {line}"
                    )
                record_name = record_name.replace("@", origin)

            records.append(
                BindRecord(
                    record_name,
                    record_ttl,
                    record_class,
                    record_type,
                    record_data,
                    prio=prio,
                    comment=comment,
                )
            )

        if origin is None:
            raise ValueError("No origin found in file")

        return BindFile(origin, ttl, records)

    def to_file(self, file_path: str) -> None:
        """Write the zone, serialized with ``str(self)``, to ``file_path``."""
        with open(file_path, "w") as f:
            f.write(str(self))

    def __str__(self) -> str:
        """Serialize as ``$ORIGIN``, optional ``$TTL``, then one record per line."""
        lines = [f"$ORIGIN {self.origin}"]
        if self.ttl is not None:
            lines.append(f"$TTL {self.ttl}")
        lines.extend(f"{record}" for record in self.records)
        # every line, including the last, is newline-terminated
        return "".join(f"{line}\n" for line in lines)