]> CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - tests/atf_python/sys/netlink/message.py
tests: add support for parsing generic netlink families.
[FreeBSD/FreeBSD.git] / tests / atf_python / sys / netlink / message.py
1 #!/usr/local/bin/python3
2 import struct
3 from ctypes import sizeof
4 from enum import Enum
5 from typing import List
6 from typing import NamedTuple
7
8 from atf_python.sys.netlink.attrs import NlAttr
9 from atf_python.sys.netlink.attrs import NlAttrNested
10 from atf_python.sys.netlink.base_headers import NlmAckFlags
11 from atf_python.sys.netlink.base_headers import NlmNewFlags
12 from atf_python.sys.netlink.base_headers import NlmGetFlags
13 from atf_python.sys.netlink.base_headers import NlmDeleteFlags
14 from atf_python.sys.netlink.base_headers import NlmBaseFlags
15 from atf_python.sys.netlink.base_headers import Nlmsghdr
16 from atf_python.sys.netlink.base_headers import NlMsgType
17 from atf_python.sys.netlink.utils import align4
18 from atf_python.sys.netlink.utils import enum_or_int
19 from atf_python.sys.netlink.utils import get_bitmask_str
20
21
22 class NlMsgCategory(Enum):
23     UNKNOWN = 0
24     GET = 1
25     NEW = 2
26     DELETE = 3
27     ACK = 4
28
29
30 class NlMsgProps(NamedTuple):
31     msg: Enum
32     category: NlMsgCategory
33
34
35 class BaseNetlinkMessage(object):
36     def __init__(self, helper, nlmsg_type):
37         self.nlmsg_type = enum_or_int(nlmsg_type)
38         self.nla_list = []
39         self._orig_data = None
40         self.helper = helper
41         self.nl_hdr = Nlmsghdr(
42             nlmsg_type=self.nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid
43         )
44         self.base_hdr = None
45
46     def set_request(self, need_ack=True):
47         self.add_nlflags([NlmBaseFlags.NLM_F_REQUEST])
48         if need_ack:
49             self.add_nlflags([NlmBaseFlags.NLM_F_ACK])
50
51     def add_nlflags(self, flags: List):
52         int_flags = 0
53         for flag in flags:
54             int_flags |= enum_or_int(flag)
55         self.nl_hdr.nlmsg_flags |= int_flags
56
57     def add_nla(self, nla):
58         self.nla_list.append(nla)
59
60     def _get_nla(self, nla_list, nla_type):
61         nla_type_raw = enum_or_int(nla_type)
62         for nla in nla_list:
63             if nla.nla_type == nla_type_raw:
64                 return nla
65         return None
66
67     def get_nla(self, nla_type):
68         return self._get_nla(self.nla_list, nla_type)
69
70     @staticmethod
71     def parse_nl_header(data: bytes):
72         if len(data) < sizeof(Nlmsghdr):
73             raise ValueError("length less than netlink message header")
74         return Nlmsghdr.from_buffer_copy(data), sizeof(Nlmsghdr)
75
76     def is_type(self, nlmsg_type):
77         nlmsg_type_raw = enum_or_int(nlmsg_type)
78         return nlmsg_type_raw == self.nl_hdr.nlmsg_type
79
80     def is_reply(self, hdr):
81         return hdr.nlmsg_type == NlMsgType.NLMSG_ERROR.value
82
83     @property
84     def msg_name(self):
85         return "msg#{}".format(self._get_msg_type())
86
87     def _get_nl_category(self):
88         if self.is_reply(self.nl_hdr):
89             return NlMsgCategory.ACK
90         return NlMsgCategory.UNKNOWN
91
92     def get_nlm_flags_str(self):
93         category = self._get_nl_category()
94         flags = self.nl_hdr.nlmsg_flags
95
96         if category == NlMsgCategory.UNKNOWN:
97             return self.helper.get_bitmask_str(NlmBaseFlags, flags)
98         elif category == NlMsgCategory.GET:
99             flags_enum = NlmGetFlags
100         elif category == NlMsgCategory.NEW:
101             flags_enum = NlmNewFlags
102         elif category == NlMsgCategory.DELETE:
103             flags_enum = NlmDeleteFlags
104         elif category == NlMsgCategory.ACK:
105             flags_enum = NlmAckFlags
106         return get_bitmask_str([NlmBaseFlags, flags_enum], flags)
107
108     def print_nl_header(self, prepend=""):
109         # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0  # noqa: E501
110         hdr = self.nl_hdr
111         print(
112             "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format(
113                 prepend,
114                 hdr.nlmsg_len,
115                 self.msg_name,
116                 self.get_nlm_flags_str(),
117                 hdr.nlmsg_flags,
118                 hdr.nlmsg_seq,
119                 hdr.nlmsg_pid,
120             )
121         )
122
123     @classmethod
124     def from_bytes(cls, helper, data):
125         try:
126             hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data)
127             self = cls(helper, hdr.nlmsg_type)
128             self._orig_data = data
129             self.nl_hdr = hdr
130         except ValueError as e:
131             print("Failed to parse nl header: {}".format(e))
132             cls.print_as_bytes(data)
133             raise
134         return self
135
136     def print_message(self):
137         self.print_nl_header()
138
139     @staticmethod
140     def print_as_bytes(data: bytes, descr: str):
141         print("===vv {} (len:{:3d}) vv===".format(descr, len(data)))
142         off = 0
143         step = 16
144         while off < len(data):
145             for i in range(step):
146                 if off + i < len(data):
147                     print(" {:02X}".format(data[off + i]), end="")
148             print("")
149             off += step
150         print("--------------------")
151
152
153 class StdNetlinkMessage(BaseNetlinkMessage):
154     nl_attrs_map = {}
155
156     @classmethod
157     def from_bytes(cls, helper, data):
158         try:
159             hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data)
160             self = cls(helper, hdr.nlmsg_type)
161             self._orig_data = data
162             self.nl_hdr = hdr
163         except ValueError as e:
164             print("Failed to parse nl header: {}".format(e))
165             cls.print_as_bytes(data)
166             raise
167
168         offset = align4(hdrlen)
169         try:
170             base_hdr, hdrlen = self.parse_base_header(data[offset:])
171             self.base_hdr = base_hdr
172             offset += align4(hdrlen)
173             # XXX: CAP_ACK
174         except ValueError as e:
175             print("Failed to parse nl rt header: {}".format(e))
176             cls.print_as_bytes(data)
177             raise
178
179         orig_offset = offset
180         try:
181             nla_list, nla_len = self.parse_nla_list(data[offset:])
182             offset += nla_len
183             if offset != len(data):
184                 raise ValueError(
185                     "{} bytes left at the end of the packet".format(len(data) - offset)
186                 )  # noqa: E501
187             self.nla_list = nla_list
188         except ValueError as e:
189             print(
190                 "Failed to parse nla attributes at offset {}: {}".format(orig_offset, e)
191             )  # noqa: E501
192             cls.print_as_bytes(data, "msg dump")
193             cls.print_as_bytes(data[orig_offset:], "failed block")
194             raise
195         return self
196
197     def parse_attrs(self, data: bytes, attr_map):
198         ret = []
199         off = 0
200         while len(data) - off >= 4:
201             nla_len, raw_nla_type = struct.unpack("@HH", data[off:off + 4])
202             if nla_len + off > len(data):
203                 raise ValueError(
204                     "attr length {} > than the remaining length {}".format(
205                         nla_len, len(data) - off
206                     )
207                 )
208             nla_type = raw_nla_type & 0x3F
209             if nla_type in attr_map:
210                 v = attr_map[nla_type]
211                 val = v["ad"].cls.from_bytes(data[off:off + nla_len], v["ad"].val)
212                 if "child" in v:
213                     # nested
214                     attrs, _ = self.parse_attrs(
215                         data[off + 4:off + nla_len], v["child"]
216                     )
217                     val = NlAttrNested(v["ad"].val, attrs)
218             else:
219                 # unknown attribute
220                 val = NlAttr(raw_nla_type, data[off + 4:off + nla_len])
221             ret.append(val)
222             off += align4(nla_len)
223         return ret, off
224
225     def parse_nla_list(self, data: bytes) -> List[NlAttr]:
226         return self.parse_attrs(data, self.nl_attrs_map)
227
228     def __bytes__(self):
229         ret = bytes()
230         for nla in self.nla_list:
231             ret += bytes(nla)
232         ret = bytes(self.base_hdr) + ret
233         self.nl_hdr.nlmsg_len = len(ret) + sizeof(Nlmsghdr)
234         return bytes(self.nl_hdr) + ret
235
236     def _get_msg_type(self):
237         return self.nl_hdr.nlmsg_type
238
239     @property
240     def msg_props(self):
241         msg_type = self._get_msg_type()
242         for msg_props in self.messages:
243             if msg_props.msg.value == msg_type:
244                 return msg_props
245         return None
246
247     @property
248     def msg_name(self):
249         msg_props = self.msg_props
250         if msg_props is not None:
251             return msg_props.msg.name
252         return super().msg_name
253
254     def print_base_header(self, hdr, prepend=""):
255         pass
256
257     def print_message(self):
258         self.print_nl_header()
259         self.print_base_header(self.base_hdr, " ")
260         for nla in self.nla_list:
261             nla.print_attr("  ")