simplify some code

This commit is contained in:
Ivan Schaller 2024-02-28 14:07:30 +01:00
parent e27f8938a5
commit 792b9d5429

View file

@ -11,9 +11,7 @@ import pynetbox.core.response
class NetBoxDNSSource(octodns.provider.base.BaseProvider): class NetBoxDNSSource(octodns.provider.base.BaseProvider):
""" """OctoDNS provider for NetboxDNS"""
OctoDNS provider for NetboxDNS
"""
SUPPORTS_GEO = False SUPPORTS_GEO = False
SUPPORTS_DYNAMIC = False SUPPORTS_DYNAMIC = False
@ -62,9 +60,7 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider):
*args, *args,
**kwargs, **kwargs,
): ):
""" """initialize the NetboxDNSSource"""
Initialize the NetboxDNSSource
"""
self.log = logging.getLogger(f"NetboxDNSSource[{id}]") self.log = logging.getLogger(f"NetboxDNSSource[{id}]")
self.log.debug(f"__init__: {id=}, {url=}, {view=}, {replace_duplicates=}, {make_absolute=}") self.log.debug(f"__init__: {id=}, {url=}, {view=}, {replace_duplicates=}, {make_absolute=}")
super().__init__(id, *args, **kwargs) super().__init__(id, *args, **kwargs)
@ -76,8 +72,7 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider):
self.make_absolute = make_absolute self.make_absolute = make_absolute
def _make_absolute(self, value: str) -> str: def _make_absolute(self, value: str) -> str:
""" """return dns name with trailing dot to make it absolute
Return dns name with trailing dot to make it absolute
@param value: dns record value @param value: dns record value
@ -92,8 +87,7 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider):
return absolute_value return absolute_value
def _get_nb_view(self, view: str | None | Literal[False]) -> dict[str, int | str]: def _get_nb_view(self, view: str | None | Literal[False]) -> dict[str, int | str]:
""" """get the correct netbox view when requested
Get the correct netbox view when requested
@param view: `False` for no view, `None` for zones without a view, else the view name @param view: `False` for no view, `None` for zones without a view, else the view name
@ -115,8 +109,7 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider):
return {"view_id": nb_view.id} return {"view_id": nb_view.id}
def _get_nb_zone(self, name: str, view: dict[str, str | int]) -> pynetbox.core.response.Record: def _get_nb_zone(self, name: str, view: dict[str, str | int]) -> pynetbox.core.response.Record:
""" """given a zone name and a view name, look it up in NetBox.
Given a zone name and a view name, look it up in NetBox.
@param name: name of the dns zone @param name: name of the dns zone
@param view: the netbox view id in the api query format @param view: the netbox view id in the api query format
@ -133,8 +126,7 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider):
return nb_zone return nb_zone
def _format_rdata(self, rdata: dns.rdata.Rdata, raw_value: str) -> str | dict[str, Any]: def _format_rdata(self, rdata: dns.rdata.Rdata, raw_value: str) -> str | dict[str, Any]:
""" """format netbox record values to correct octodns record values
Format netbox record values to correct octodns record values
@param rdata: rrdata record value @param rdata: rrdata record value
@param raw_value: raw record value @param raw_value: raw record value
@ -221,8 +213,7 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider):
return value # type:ignore return value # type:ignore
def _format_nb_records(self, zone: octodns.zone.Zone) -> list[dict[str, Any]]: def _format_nb_records(self, zone: octodns.zone.Zone) -> list[dict[str, Any]]:
""" """format netbox dns records to the octodns format
Format netbox dns records to the octodns format
@param zone: octodns zone @param zone: octodns zone
@ -273,12 +264,13 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider):
def populate( def populate(
self, zone: octodns.zone.Zone, target: bool = False, lenient: bool = False self, zone: octodns.zone.Zone, target: bool = False, lenient: bool = False
) -> bool: ) -> bool:
""" """get all the records of a zone from NetBox and add them to the OctoDNS zone
Get all the records of a zone from NetBox and add them to the OctoDNS zone
@param zone: octodns zone @param zone: octodns zone
@param target: when `True`, load the current state of the provider. @param target: when `True`, load the current state of the provider.
@param lenient: when `True`, skip record validation and do a "best effort" load of data. @param lenient: when `True`, skip record validation and do a "best effort" load of data.
@return: true if the zone exists, else false.
""" """
self.log.info(f"populate -> '{zone.name}', target={target}, lenient={lenient}") self.log.info(f"populate -> '{zone.name}', target={target}, lenient={lenient}")
@ -306,15 +298,38 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider):
return True return True
def _include_change(self, change: octodns.record.change.Change) -> bool: @staticmethod
"""Filter out record types which the provider can't create in netbox""" def __format_changeset(change: Any) -> set[str]:
"""format the changeset
@param change: the raw changes
@return: the formatted changeset
"""
match change:
case octodns.record.ValueMixin():
changeset = {repr(change.value)[1:-1]}
case octodns.record.ValuesMixin():
changeset = {repr(v)[1:-1] for v in change.values}
case _:
raise ValueError
return changeset
@staticmethod
def _include_change(change: octodns.record.change.Change) -> bool:
"""filter out record types which the provider can't create in netbox
@param change: the planned change
@return: false if the change should be discarded, true if it should be kept.
"""
if change.new._type in ["SOA", "PTR", "NS"]: if change.new._type in ["SOA", "PTR", "NS"]:
return False return False
return True return True
def _apply(self, plan: octodns.provider.plan.Plan) -> None: def _apply(self, plan: octodns.provider.plan.Plan) -> None:
"""Apply the changes to the NetBox DNS zone.""" """apply the changes to the NetBox DNS zone."""
self.log.debug(f"_apply: zone={plan.desired.name}, len(changes)={len(plan.changes)}") self.log.debug(f"_apply: zone={plan.desired.name}, len(changes)={len(plan.changes)}")
nb_zone = self._get_nb_zone(plan.desired.name, view=self.nb_view) nb_zone = self._get_nb_zone(plan.desired.name, view=self.nb_view)
@ -322,20 +337,12 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider):
for change in plan.changes: for change in plan.changes:
match change: match change:
case octodns.record.Create(): case octodns.record.Create():
name = change.new.name name = "@" if change.name.name == "" else change.name.name
if name == "":
name = "@"
match change.new:
case octodns.record.ValueMixin():
new = {repr(change.new.value)[1:-1]}
case octodns.record.ValuesMixin():
new = {repr(v)[1:-1] for v in change.new.values}
case _:
raise ValueError
new = self.__format_changeset(change.new)
for value in new: for value in new:
nb_record = self.api.plugins.netbox_dns.records.create( nb_record: pynetbox.core.response.Record = (
self.api.plugins.netbox_dns.records.create(
zone=nb_zone.id, zone=nb_zone.id,
name=name, name=name,
type=change.new._type, type=change.new._type,
@ -343,27 +350,21 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider):
value=value.replace("\\\\", "\\").replace("\\;", ";"), value=value.replace("\\\\", "\\").replace("\\;", ";"),
disable_ptr=True, disable_ptr=True,
) )
)
self.log.debug(f"{nb_record!r}") self.log.debug(f"{nb_record!r}")
case octodns.record.Delete(): case octodns.record.Delete():
name = change.existing.name name = "@" if change.existing.name == "" else change.existing.name
if name == "":
name = "@"
nb_records = self.api.plugins.netbox_dns.records.filter( nb_records: pynetbox.core.response.RecordSet = (
self.api.plugins.netbox_dns.records.filter(
zone_id=nb_zone.id, zone_id=nb_zone.id,
name=change.existing.name, name=change.existing.name,
type=change.existing._type, type=change.existing._type,
) )
)
match change.existing: existing = self.__format_changeset(change.existing)
case octodns.record.ValueMixin():
existing = {repr(change.existing.value)[1:-1]}
case octodns.record.ValuesMixin():
existing = {repr(v)[1:-1] for v in change.existing.values}
case _:
raise ValueError
for nb_record in nb_records: for nb_record in nb_records:
for value in existing: for value in existing:
if nb_record.value == value: if nb_record.value == value:
@ -374,31 +375,18 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider):
nb_record.delete() nb_record.delete()
case octodns.record.Update(): case octodns.record.Update():
name = change.existing.name name = "@" if change.existing.name == "" else change.existing.name
if name == "":
name = "@"
nb_records = self.api.plugins.netbox_dns.records.filter( nb_records: pynetbox.core.response.RecordSet = (
self.api.plugins.netbox_dns.records.filter(
zone_id=nb_zone.id, zone_id=nb_zone.id,
name=name, name=name,
type=change.existing._type, type=change.existing._type,
) )
)
match change.existing: existing = self.__format_changeset(change.existing)
case octodns.record.ValueMixin(): new = self.__format_changeset(change.new)
existing = {repr(change.existing.value)[1:-1]}
case octodns.record.ValuesMixin():
existing = {repr(v)[1:-1] for v in change.existing.values}
case _:
raise ValueError
match change.new:
case octodns.record.ValueMixin():
new = {repr(change.new.value)[1:-1]}
case octodns.record.ValuesMixin():
new = {repr(v)[1:-1] for v in change.new.values}
case _:
raise ValueError
delete = existing.difference(new) delete = existing.difference(new)
update = existing.intersection(new) update = existing.intersection(new)
@ -412,7 +400,8 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider):
nb_record.save() nb_record.save()
for value in create: for value in create:
nb_record = self.api.plugins.netbox_dns.records.create( nb_record: pynetbox.core.response.Record = (
self.api.plugins.netbox_dns.records.create(
zone=nb_zone.id, zone=nb_zone.id,
name=name, name=name,
type=change.new._type, type=change.new._type,
@ -420,4 +409,5 @@ class NetBoxDNSSource(octodns.provider.base.BaseProvider):
value=value.replace("\\\\", "\\").replace("\\;", ";"), value=value.replace("\\\\", "\\").replace("\\;", ";"),
disable_ptr=True, disable_ptr=True,
) )
)
self.log.debug(f"{nb_record!r}") self.log.debug(f"{nb_record!r}")