1 #!/usr/local/bin/python3
3 from ctypes import sizeof
5 from typing import List
6 from typing import NamedTuple
8 from atf_python.sys.netlink.attrs import NlAttr
9 from atf_python.sys.netlink.attrs import NlAttrNested
10 from atf_python.sys.netlink.base_headers import NlmAckFlags
11 from atf_python.sys.netlink.base_headers import NlmNewFlags
12 from atf_python.sys.netlink.base_headers import NlmGetFlags
13 from atf_python.sys.netlink.base_headers import NlmDeleteFlags
14 from atf_python.sys.netlink.base_headers import NlmBaseFlags
15 from atf_python.sys.netlink.base_headers import Nlmsghdr
16 from atf_python.sys.netlink.base_headers import NlMsgType
17 from atf_python.sys.netlink.utils import align4
18 from atf_python.sys.netlink.utils import enum_or_int
19 from atf_python.sys.netlink.utils import get_bitmask_str
class NlMsgCategory(Enum):
    """Coarse kind of a netlink message.

    get_nlm_flags_str() uses the category to decide which family of
    request-type-specific flag names (GET/NEW/DELETE/ACK) applies.
    """

    UNKNOWN = 0
    GET = 1
    NEW = 2
    DELETE = 3
    ACK = 4
class NlMsgProps(NamedTuple):
    """Static description of one message type a message class handles."""

    # Enum member identifying the message type; its .value is matched
    # against nlmsg_type and its .name is used as the display name.
    msg: Enum
    # Category used to choose which flag names to print.
    category: NlMsgCategory
class BaseNetlinkMessage(object):
    """A netlink message: an ``Nlmsghdr`` plus a list of attributes.

    This class only understands the netlink header itself; subclasses may
    add a second ("base") header and attribute parsing.
    """

    def __init__(self, helper, nlmsg_type):
        """Create an empty message of type *nlmsg_type*.

        *helper* provides ``get_seq()`` and ``pid``, which are stamped into
        the new netlink header as the sequence number and port id.
        """
        self.nlmsg_type = enum_or_int(nlmsg_type)
        self.nla_list = []
        self._orig_data = None  # raw bytes when built via from_bytes()
        self.helper = helper
        self.nl_hdr = Nlmsghdr(
            nlmsg_type=self.nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid
        )
        self.base_hdr = None  # set by subclasses that carry a base header

    def set_request(self, need_ack=True):
        """Flag the message as a request, also asking for an ACK by default."""
        self.add_nlflags([NlmBaseFlags.NLM_F_REQUEST])
        if need_ack:
            self.add_nlflags([NlmBaseFlags.NLM_F_ACK])

    def add_nlflags(self, flags: List):
        """OR every flag in *flags* (enum members or ints) into the header."""
        int_flags = 0
        for flag in flags:
            int_flags |= enum_or_int(flag)
        self.nl_hdr.nlmsg_flags |= int_flags

    def add_nla(self, nla):
        """Append attribute *nla* to the message."""
        self.nla_list.append(nla)

    def _get_nla(self, nla_list, nla_type):
        """Return the first attribute in *nla_list* of type *nla_type*, or None."""
        nla_type_raw = enum_or_int(nla_type)
        return next((nla for nla in nla_list if nla.nla_type == nla_type_raw), None)

    def get_nla(self, nla_type):
        """Return this message's first attribute of type *nla_type*, or None."""
        return self._get_nla(self.nla_list, nla_type)

    @staticmethod
    def parse_nl_header(data: bytes):
        """Decode the leading ``Nlmsghdr`` of *data*.

        Returns ``(header, header_size)``.
        Raises ValueError if *data* is shorter than the header.
        """
        if len(data) < sizeof(Nlmsghdr):
            raise ValueError("length less than netlink message header")
        return Nlmsghdr.from_buffer_copy(data), sizeof(Nlmsghdr)

    def is_type(self, nlmsg_type):
        """Return True if the header's type equals *nlmsg_type*."""
        return enum_or_int(nlmsg_type) == self.nl_hdr.nlmsg_type

    def is_reply(self, hdr):
        """Return True if *hdr* is an NLMSG_ERROR (ACK/error reply) header."""
        return hdr.nlmsg_type == NlMsgType.NLMSG_ERROR.value

    def _get_msg_type(self):
        """Return the raw message type from the netlink header."""
        # Defined here so msg_name also works on plain BaseNetlinkMessage
        # instances; subclasses may override it.
        return self.nl_hdr.nlmsg_type

    @property
    def msg_name(self):
        """Display name of the message; the base class only knows the number."""
        return "msg#{}".format(self._get_msg_type())

    def _get_nl_category(self):
        """Classify the message: ACK for NLMSG_ERROR replies, else UNKNOWN."""
        if self.is_reply(self.nl_hdr):
            return NlMsgCategory.ACK
        return NlMsgCategory.UNKNOWN

    def get_nlm_flags_str(self):
        """Render the header flags as names.

        The upper flag bits mean different things depending on the request
        kind, so the category selects which flag enum decodes them.
        """
        category = self._get_nl_category()
        flags = self.nl_hdr.nlmsg_flags

        if category == NlMsgCategory.UNKNOWN:
            return self.helper.get_bitmask_str(NlmBaseFlags, flags)
        flags_enum = {
            NlMsgCategory.GET: NlmGetFlags,
            NlMsgCategory.NEW: NlmNewFlags,
            NlMsgCategory.DELETE: NlmDeleteFlags,
            NlMsgCategory.ACK: NlmAckFlags,
        }[category]
        return get_bitmask_str([NlmBaseFlags, flags_enum], flags)

    def print_nl_header(self, prepend=""):
        """Print a one-line summary of the netlink header, prefixed by *prepend*."""
        # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0 # noqa: E501
        hdr = self.nl_hdr
        print(
            "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format(
                prepend,
                hdr.nlmsg_len,
                self.msg_name,
                self.get_nlm_flags_str(),
                hdr.nlmsg_flags,
                hdr.nlmsg_seq,
                hdr.nlmsg_pid,
            )
        )

    @classmethod
    def from_bytes(cls, helper, data):
        """Build a message of this class from raw bytes *data*.

        Only the netlink header is parsed. On failure the error and a hex
        dump of *data* are printed and the ValueError is re-raised.
        """
        try:
            hdr, _ = BaseNetlinkMessage.parse_nl_header(data)
            msg = cls(helper, hdr.nlmsg_type)
            msg._orig_data = data
            msg.nl_hdr = hdr
        except ValueError as e:
            print("Failed to parse nl header: {}".format(e))
            # Pass descr explicitly: a missing argument here would raise
            # TypeError and hide the original parse error.
            cls.print_as_bytes(data, "msg dump")
            raise
        return msg

    def print_message(self):
        """Print the message; the base class only has a header to show."""
        self.print_nl_header()

    @staticmethod
    def print_as_bytes(data: bytes, descr: str = "msg dump"):
        """Hex-dump *data* to stdout, 16 bytes per line, framed by *descr*.

        *descr* has a default so that single-argument callers work instead
        of raising TypeError.
        """
        print("===vv {} (len:{:3d}) vv===".format(descr, len(data)))
        step = 16
        for off in range(0, len(data), step):
            print("".join(" {:02X}".format(b) for b in data[off:off + step]))
        print("--------------------")
153 class StdNetlinkMessage(BaseNetlinkMessage):
157 def from_bytes(cls, helper, data):
159 hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data)
160 self = cls(helper, hdr.nlmsg_type)
161 self._orig_data = data
163 except ValueError as e:
164 print("Failed to parse nl header: {}".format(e))
165 cls.print_as_bytes(data)
168 offset = align4(hdrlen)
170 base_hdr, hdrlen = self.parse_base_header(data[offset:])
171 self.base_hdr = base_hdr
172 offset += align4(hdrlen)
174 except ValueError as e:
175 print("Failed to parse nl rt header: {}".format(e))
176 cls.print_as_bytes(data)
181 nla_list, nla_len = self.parse_nla_list(data[offset:])
183 if offset != len(data):
185 "{} bytes left at the end of the packet".format(len(data) - offset)
187 self.nla_list = nla_list
188 except ValueError as e:
190 "Failed to parse nla attributes at offset {}: {}".format(orig_offset, e)
192 cls.print_as_bytes(data, "msg dump")
193 cls.print_as_bytes(data[orig_offset:], "failed block")
197 def parse_attrs(self, data: bytes, attr_map):
200 while len(data) - off >= 4:
201 nla_len, raw_nla_type = struct.unpack("@HH", data[off:off + 4])
202 if nla_len + off > len(data):
204 "attr length {} > than the remaining length {}".format(
205 nla_len, len(data) - off
208 nla_type = raw_nla_type & 0x3F
209 if nla_type in attr_map:
210 v = attr_map[nla_type]
211 val = v["ad"].cls.from_bytes(data[off:off + nla_len], v["ad"].val)
214 attrs, _ = self.parse_attrs(
215 data[off + 4:off + nla_len], v["child"]
217 val = NlAttrNested(v["ad"].val, attrs)
220 val = NlAttr(raw_nla_type, data[off + 4:off + nla_len])
222 off += align4(nla_len)
225 def parse_nla_list(self, data: bytes) -> List[NlAttr]:
226 return self.parse_attrs(data, self.nl_attrs_map)
230 for nla in self.nla_list:
232 ret = bytes(self.base_hdr) + ret
233 self.nl_hdr.nlmsg_len = len(ret) + sizeof(Nlmsghdr)
234 return bytes(self.nl_hdr) + ret
236 def _get_msg_type(self):
237 return self.nl_hdr.nlmsg_type
241 msg_type = self._get_msg_type()
242 for msg_props in self.messages:
243 if msg_props.msg.value == msg_type:
249 msg_props = self.msg_props
250 if msg_props is not None:
251 return msg_props.msg.name
252 return super().msg_name
254 def print_base_header(self, hdr, prepend=""):
257 def print_message(self):
258 self.print_nl_header()
259 self.print_base_header(self.base_hdr, " ")
260 for nla in self.nla_list: