diff --git a/src/octodns_netbox_dns/__init__.py b/src/octodns_netbox_dns/__init__.py index 9f99416..f1d1b27 100644 --- a/src/octodns_netbox_dns/__init__.py +++ b/src/octodns_netbox_dns/__init__.py @@ -11,9 +11,7 @@ import pynetbox.core.response class NetBoxDNSSource(octodns.provider.base.BaseProvider): - """ - OctoDNS provider for NetboxDNS - """ + """OctoDNS provider for NetboxDNS""" SUPPORTS_GEO = False SUPPORTS_DYNAMIC = False @@ -62,9 +60,7 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider): *args, **kwargs, ): - """ - Initialize the NetboxDNSSource - """ + """initialize the NetboxDNSSource""" self.log = logging.getLogger(f"NetboxDNSSource[{id}]") self.log.debug(f"__init__: {id=}, {url=}, {view=}, {replace_duplicates=}, {make_absolute=}") super().__init__(id, *args, **kwargs) @@ -76,8 +72,7 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider): self.make_absolute = make_absolute def _make_absolute(self, value: str) -> str: - """ - Return dns name with trailing dot to make it absolute + """return dns name with trailing dot to make it absolute @param value: dns record value @@ -92,8 +87,7 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider): return absolute_value def _get_nb_view(self, view: str | None | Literal[False]) -> dict[str, int | str]: - """ - Get the correct netbox view when requested + """get the correct netbox view when requested @param view: `False` for no view, `None` for zones without a view, else the view name @@ -115,8 +109,7 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider): return {"view_id": nb_view.id} def _get_nb_zone(self, name: str, view: dict[str, str | int]) -> pynetbox.core.response.Record: - """ - Given a zone name and a view name, look it up in NetBox. + """given a zone name and a view name, look it up in NetBox. @param name: name of the dns zone @param view: the netbox view id in the api query format @@ -133,8 +126,7 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider): return nb_zone def _format_rdata(self, rdata: dns.rdata.Rdata, raw_value: str) -> str | dict[str, Any]: - """ - Format netbox record values to correct octodns record values + """format netbox record values to correct octodns record values @param rdata: rrdata record value @param raw_value: raw record value @@ -221,8 +213,7 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider): return value # type:ignore def _format_nb_records(self, zone: octodns.zone.Zone) -> list[dict[str, Any]]: - """ - Format netbox dns records to the octodns format + """format netbox dns records to the octodns format @param zone: octodns zone @@ -273,12 +264,13 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider): def populate( self, zone: octodns.zone.Zone, target: bool = False, lenient: bool = False ) -> bool: - """ - Get all the records of a zone from NetBox and add them to the OctoDNS zone + """get all the records of a zone from NetBox and add them to the OctoDNS zone @param zone: octodns zone @param target: when `True`, load the current state of the provider. @param lenient: when `True`, skip record validation and do a "best effort" load of data. + + @return: true if the zone exists, else false. """ self.log.info(f"populate -> '{zone.name}', target={target}, lenient={lenient}") @@ -306,15 +298,38 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider): return True - def _include_change(self, change: octodns.record.change.Change) -> bool: - """Filter out record types which the provider can't create in netbox""" + @staticmethod + def __format_changeset(change: Any) -> set[str]: + """format the changeset + + @param change: the raw changes + + @return: the formatted changeset + """ + match change: + case octodns.record.ValueMixin(): + changeset = {repr(change.value)[1:-1]} + case octodns.record.ValuesMixin(): + changeset = {repr(v)[1:-1] for v in change.values} + case _: + raise ValueError + + return changeset + + @staticmethod + def _include_change(change: octodns.record.change.Change) -> bool: + """filter out record types which the provider can't create in netbox + @param change: the planned change + + @return: false if the change should be discarded, true if it should be kept. + """ if change.new._type in ["SOA", "PTR", "NS"]: return False return True def _apply(self, plan: octodns.provider.plan.Plan) -> None: - """Apply the changes to the NetBox DNS zone.""" + """apply the changes to the NetBox DNS zone.""" self.log.debug(f"_apply: zone={plan.desired.name}, len(changes)={len(plan.changes)}") nb_zone = self._get_nb_zone(plan.desired.name, view=self.nb_view) @@ -322,48 +337,34 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider): for change in plan.changes: match change: case octodns.record.Create(): - name = change.new.name - if name == "": - name = "@" - - match change.new: - case octodns.record.ValueMixin(): - new = {repr(change.new.value)[1:-1]} - case octodns.record.ValuesMixin(): - new = {repr(v)[1:-1] for v in change.new.values} - case _: - raise ValueError + name = "@" if change.name.name == "" else change.name.name + new = self.__format_changeset(change.new) for value in new: - nb_record = self.api.plugins.netbox_dns.records.create( - zone=nb_zone.id, - name=name, - type=change.new._type, - ttl=change.new.ttl, - value=value.replace("\\\\", "\\").replace("\\;", ";"), - disable_ptr=True, + nb_record: pynetbox.core.response.Record = ( + self.api.plugins.netbox_dns.records.create( + zone=nb_zone.id, + name=name, + type=change.new._type, + ttl=change.new.ttl, + value=value.replace("\\\\", "\\").replace("\\;", ";"), + disable_ptr=True, + ) ) self.log.debug(f"{nb_record!r}") case octodns.record.Delete(): - name = change.existing.name - if name == "": - name = "@" + name = "@" if change.existing.name == "" else change.existing.name - nb_records = self.api.plugins.netbox_dns.records.filter( - zone_id=nb_zone.id, - name=change.existing.name, - type=change.existing._type, + nb_records: pynetbox.core.response.RecordSet = ( + self.api.plugins.netbox_dns.records.filter( + zone_id=nb_zone.id, + name=change.existing.name, + type=change.existing._type, + ) ) - match change.existing: - case octodns.record.ValueMixin(): - existing = {repr(change.existing.value)[1:-1]} - case octodns.record.ValuesMixin(): - existing = {repr(v)[1:-1] for v in change.existing.values} - case _: - raise ValueError - + existing = self.__format_changeset(change.existing) for nb_record in nb_records: for value in existing: if nb_record.value == value: @@ -374,31 +375,18 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider): nb_record.delete() case octodns.record.Update(): - name = change.existing.name - if name == "": - name = "@" + name = "@" if change.existing.name == "" else change.existing.name - nb_records = self.api.plugins.netbox_dns.records.filter( - zone_id=nb_zone.id, - name=name, - type=change.existing._type, + nb_records: pynetbox.core.response.RecordSet = ( + self.api.plugins.netbox_dns.records.filter( + zone_id=nb_zone.id, + name=name, + type=change.existing._type, + ) ) - match change.existing: - case octodns.record.ValueMixin(): - existing = {repr(change.existing.value)[1:-1]} - case octodns.record.ValuesMixin(): - existing = {repr(v)[1:-1] for v in change.existing.values} - case _: - raise ValueError - - match change.new: - case octodns.record.ValueMixin(): - new = {repr(change.new.value)[1:-1]} - case octodns.record.ValuesMixin(): - new = {repr(v)[1:-1] for v in change.new.values} - case _: - raise ValueError + existing = self.__format_changeset(change.existing) + new = self.__format_changeset(change.new) delete = existing.difference(new) update = existing.intersection(new) @@ -412,12 +400,14 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider): nb_record.save() for value in create: - nb_record = self.api.plugins.netbox_dns.records.create( - zone=nb_zone.id, - name=name, - type=change.new._type, - ttl=change.new.ttl, - value=value.replace("\\\\", "\\").replace("\\;", ";"), - disable_ptr=True, + nb_record: pynetbox.core.response.Record = ( + self.api.plugins.netbox_dns.records.create( + zone=nb_zone.id, + name=name, + type=change.new._type, + ttl=change.new.ttl, + value=value.replace("\\\\", "\\").replace("\\;", ";"), + disable_ptr=True, + ) ) self.log.debug(f"{nb_record!r}")