CyberLeo.Net >> Repos - FreeBSD/FreeBSD.git/blob - tests/atf_python/atf_pytest.py
atf_python: Standardize custom sections
[FreeBSD/FreeBSD.git] / tests / atf_python / atf_pytest.py
1 import types
2 from typing import Any
3 from typing import Dict
4 from typing import List
5 from typing import NamedTuple
6 from typing import Optional
7 from typing import Tuple
8
9 from atf_python.ktest import generate_ktests
10 from atf_python.utils import nodeid_to_method_name
11
12 import pytest
13 import os
14
15
class ATFCleanupItem(pytest.Item):
    """Source of replacement methods used when running tests in cleanup mode.

    ``ATFHandler.override_runtest`` grafts ``runtest`` onto a collected item
    and installs the no-op setup/teardown methods on the item's class.
    """

    def runtest(self):
        """Run the test's cleanup routine in place of the test body.

        A method named ``cleanup_<method name>`` (derived from the node id)
        wins over a generic ``cleanup`` method; the chosen routine receives
        the node id. When the class has neither, nothing happens.
        """
        test_instance = self.parent.cls()
        specific_name = "cleanup_{}".format(nodeid_to_method_name(self.nodeid))
        # Most specific routine first; at most one routine is invoked.
        for candidate in (specific_name, "cleanup"):
            if hasattr(test_instance, candidate):
                getattr(test_instance, candidate)(self.nodeid)
                break

    def setup_method_noop(self, method):
        """Replacement for the class setup_method: intentionally does nothing."""

    def teardown_method_noop(self, method):
        """Replacement for the class teardown_method: intentionally does nothing."""
34
35
class ATFTestObj(object):
    """ATF test-program listing entry for a single collected pytest item.

    Converts the item's node id, description and selected pytest marks
    into the ``key: value`` lines returned by ``as_lines``.
    """

    def __init__(self, obj, has_cleanup):
        """
        :param obj: collected pytest item; ``nodeid``, ``name`` and
            ``iter_markers()`` are used, plus ``descr`` or ``function``.
        :param has_cleanup: whether to emit ``has.cleanup: true``.
        """
        # Drop the leading file path from the nodeid so class-derived tests
        # keep their "Class::method" identifier.
        self.ident = obj.nodeid.split("::", 1)[1]
        self.description = self._get_test_description(obj)
        self.has_cleanup = has_cleanup
        self.obj = obj

    def _get_test_description(self, obj):
        """Return the test description.

        Priority: an explicit non-None ``descr`` attribute on the item, then
        the first non-blank line of the test function's docstring (with
        surrounding whitespace stripped), then the item name.
        """
        descr = getattr(obj, "descr", None)
        if descr is not None:
            return descr
        func = getattr(obj, "function", None)
        docstr = func.__doc__ if func is not None else None
        if docstr:
            for line in docstr.splitlines():
                # Docstrings are usually indented; a whitespace-only or
                # indented line must not leak padding into the ATF output.
                line = line.strip()
                if line:
                    return line
        return obj.name

    @staticmethod
    def _join_words(value) -> str:
        """Format a mark argument as a space-separated ATF list.

        A single string is returned unchanged (joining it directly would
        interleave its characters with spaces); other iterables are joined.
        """
        if isinstance(value, str):
            return value
        return " ".join(str(item) for item in value)

    @staticmethod
    def _convert_user_mark(mark, obj, ret: Dict):
        """Translate a ``require_user`` mark into ATF properties in *ret*.

        The special user "unprivileged" additionally appends
        ``unprivileged_user`` to ``require.config``. If the test class sets
        ``NEED_ROOT``, root is requested instead and the class is expected
        to switch to the desired user itself.
        """
        username = mark.args[0]
        if username == "unprivileged":
            # Special unprivileged user requested.
            # First, require the unprivileged-user config option presence
            key = "require.config"
            if key not in ret:
                ret[key] = "unprivileged_user"
            else:
                ret[key] = "{} {}".format(ret[key], "unprivileged_user")
        # Check if the framework requires root
        test_cls = ATFHandler.get_test_class(obj)
        if test_cls and getattr(test_cls, "NEED_ROOT", False):
            # Yes, so we ask kyua to run us under root instead
            # It is up to the implementation to switch back to the desired
            # user
            ret["require.user"] = "root"
        else:
            ret["require.user"] = username

    def _convert_marks(self, obj) -> Dict[str, Any]:
        """Map the item's known pytest marks to ATF property key/values.

        Each entry may rename the key ("name"), format the first mark
        argument ("fmt") or delegate entirely to a "handler". Unknown marks
        are ignored; later marks overwrite earlier ones with the same key.
        """
        _map: Dict[str, Dict] = {
            "require_user": {"handler": self._convert_user_mark},
            "require_arch": {"name": "require.arch", "fmt": self._join_words},
            "require_diskspace": {"name": "require.diskspace"},
            "require_files": {"name": "require.files", "fmt": self._join_words},
            "require_machine": {
                "name": "require.machine",
                "fmt": self._join_words,
            },
            "require_memory": {"name": "require.memory"},
            "require_progs": {"name": "require.progs", "fmt": self._join_words},
            "timeout": {},
        }
        ret = {}
        for mark in obj.iter_markers():
            spec = _map.get(mark.name)
            if spec is None:
                continue
            if "handler" in spec:
                spec["handler"](mark, obj, ret)
                continue
            name = spec.get("name", mark.name)
            if "fmt" in spec:
                val = spec["fmt"](mark.args[0])
            else:
                val = mark.args[0]
            ret[name] = val
        return ret

    def as_lines(self) -> List[str]:
        """Output test definition in ATF-specific format"""
        ret = []
        ret.append("ident: {}".format(self.ident))
        ret.append("descr: {}".format(self.description))
        if self.has_cleanup:
            ret.append("has.cleanup: true")
        for key, value in self._convert_marks(self.obj).items():
            ret.append("{}: {}".format(key, value))
        return ret
112
113
class ATFHandler(object):
    """Adapts pytest collection and reporting to the ATF test-program protocol.

    Responsibilities:
    * listing collected tests in ``application/X-atf-tp`` format;
    * when ``config.option.atf_cleanup`` is set, keeping only tests with a
      cleanup routine and making them run that routine instead of the body;
    * folding per-stage pytest reports into one ATF result per test and
      writing it to the report file.
    """

    class ReportState(NamedTuple):
        # ATF result state, e.g. "passed", "failed", "skipped",
        # "expected_failure".
        state: str
        # Human-readable reason; None when pytest supplied no report text.
        reason: Optional[str]

    def __init__(self, report_file_name: Optional[str]):
        """
        :param report_file_name: path of the ATF result file to write, or
            None/empty to disable result writing.
        """
        # Keyed by test name (report.location[2]); insertion-ordered.
        self._tests_state_map: Dict[str, "ATFHandler.ReportState"] = {}
        self._report_file_name = report_file_name
        self._report_file_handle = None

    def setup_configure(self):
        """Open the result file for writing, if one was requested."""
        fname = self._report_file_name
        if fname:
            self._report_file_handle = open(fname, mode="w")

    def setup_method_pre(self, item):
        """Called before actually running the test setup_method.

        Stores the user requested via ``require_user`` as ``TARGET_USER`` on
        the test class so the class can drop privileges itself. Tests that
        are not defined in a class have nothing to configure; for them the
        listing requests that user directly via ``require.user``.
        """
        for mark in item.iter_markers():
            if mark.name == "require_user":
                cls = self.get_test_class(item)
                if cls is not None:
                    cls.TARGET_USER = mark.args[0]
                break

    def override_runtest(self, obj):
        """Make item *obj* run its cleanup routine instead of the test body."""
        # Override basic runtest command
        obj.runtest = types.MethodType(ATFCleanupItem.runtest, obj)
        # Override class setup/teardown so that only the cleanup runs
        obj.parent.cls.setup_method = ATFCleanupItem.setup_method_noop
        obj.parent.cls.teardown_method = ATFCleanupItem.teardown_method_noop

    @staticmethod
    def get_test_class(obj):
        """Return the class owning test item *obj*, or None if there is none."""
        parent = getattr(obj, "parent", None)
        if parent is not None and hasattr(parent, "cls"):
            return parent.cls
        return None

    def has_object_cleanup(self, obj):
        """True if the item's class defines ``cleanup`` or ``cleanup_<method>``."""
        cls = self.get_test_class(obj)
        if cls is not None:
            method_name = nodeid_to_method_name(obj.nodeid)
            cleanup_name = "cleanup_{}".format(method_name)
            if hasattr(cls, "cleanup") or hasattr(cls, cleanup_name):
                return True
        return False

    def _generate_test_cleanups(self, items):
        """Keep only items with a cleanup routine, rewired to run it (in place)."""
        new_items = []
        for obj in items:
            if self.has_object_cleanup(obj):
                self.override_runtest(obj)
                new_items.append(obj)
        # Mutate the caller's list: pytest expects in-place modification.
        items.clear()
        items.extend(new_items)

    def expand_tests(self, collector, name, obj):
        """Delegate test expansion to ``generate_ktests``."""
        return generate_ktests(collector, name, obj)

    def modify_tests(self, items, config):
        """Switch collected *items* to cleanup mode when requested."""
        if config.option.atf_cleanup:
            self._generate_test_cleanups(items)

    def list_tests(self, tests: List[Any]):
        """Print the ATF test-program listing for collected pytest items."""
        print('Content-Type: application/X-atf-tp; version="1"')
        print()
        for test_obj in tests:
            has_cleanup = self.has_object_cleanup(test_obj)
            atf_test = ATFTestObj(test_obj, has_cleanup)
            for line in atf_test.as_lines():
                print(line)
            print()

    def set_report_state(self, test_name: str, state: str, reason: Optional[str]):
        """Record (overwrite) the ATF result for *test_name*."""
        self._tests_state_map[test_name] = self.ReportState(state, reason)

    def _extract_report_reason(self, report) -> Optional[str]:
        """Return a one-line reason from ``report.longrepr``, or None."""
        data = report.longrepr
        if data is None:
            return None
        if isinstance(data, tuple):
            # ('/path/to/test.py', 23, 'Skipped: unable to test')
            reason = data[2]
            # Tuple of whole prefixes: iterating a bare string would strip
            # individual characters instead.
            for prefix in ("Skipped: ",):
                if reason.startswith(prefix):
                    reason = reason[len(prefix):]
                    break
            return reason
        # string / traceback / exception report: keep only the last line
        return str(data).split("\n")[-1]

    def add_report(self, report):
        # MAP pytest report state to the atf-desired state
        #
        # ATF test states:
        # (1) expected_death, (2) expected_exit, (3) expected_failure
        # (4) expected_signal, (5) expected_timeout, (6) passed
        # (7) skipped, (8) failed
        #
        # Note that ATF doesn't have the concept of "soft xfail" - xpass
        # is a failure. It also calls the teardown routine in a separate
        # process, thus teardown states (pytest-only) are handled as
        # body continuation.

        # (stage, state, wasxfail)

        # Just a passing test: WANT: passed
        # GOT: (setup, passed, F), (call, passed, F), (teardown, passed, F)
        #
        # Failing body test: WANT: failed
        # GOT: (setup, passed, F), (call, failed, F), (teardown, passed, F)
        #
        # pytest.skip test decorator: WANT: skipped
        # GOT: (setup,skipped, False), (teardown, passed, False)
        #
        # pytest.skip call inside test function: WANT: skipped
        # GOT: (setup, passed, F), (call, skipped, F), (teardown,passed, F)
        #
        # mark.xfail decorator+pytest.xfail: WANT: expected_failure
        # GOT: (setup, passed, F), (call, skipped, T), (teardown, passed, F)
        #
        # mark.xfail decorator+pass: WANT: failed
        # GOT: (setup, passed, F), (call, passed, T), (teardown, passed, F)

        test_name = report.location[2]
        stage = report.when
        state = report.outcome
        reason = self._extract_report_reason(report)
        # pytest marks xfail-related reports with a ``wasxfail`` attribute
        # on the report object itself.
        was_xfail = hasattr(report, "wasxfail")

        # We don't care about strict xfail - it gets translated to False

        if stage == "setup":
            if state in ("skipped", "failed"):
                # failed init -> failed test, skipped setup -> xskip
                # for the whole test
                self.set_report_state(test_name, state, reason)
        elif stage == "call":
            # "call" stage shouldn't matter if setup failed
            previous = self._tests_state_map.get(test_name)
            if previous is not None and previous.state == "failed":
                return
            if state == "failed":
                # Record failure  & override "skipped" state
                self.set_report_state(test_name, state, reason)
            elif state == "skipped":
                if was_xfail:
                    # xfail() called in the test body / xfail mark hit
                    state = "expected_failure"
                # otherwise: plain skip inside the body
                self.set_report_state(test_name, state, reason)
            elif state == "passed":
                if was_xfail:
                    # the test was expected to fail but didn't:
                    # mark as hard failure
                    state = "failed"
                    if reason is None:
                        # Passing reports carry no longrepr; avoid emitting
                        # "failed: None".
                        reason = "[XPASS] {}".format(report.wasxfail).rstrip()
                self.set_report_state(test_name, state, reason)
        elif stage == "teardown":
            if state == "failed":
                # teardown should be empty, as the cleanup
                # procedures should be implemented as a separate
                # function/method, so mark teardown failure as
                # global failure
                self.set_report_state(test_name, state, reason)

    def write_report(self):
        """Write the recorded ATF result line and close the result file.

        Safe to call more than once: later calls are no-ops.
        """
        if self._report_file_handle is None:
            return
        try:
            if self._tests_state_map:
                # If we're executing in ATF mode, there has to be just one
                # test. Anyway, deterministically pick the first one
                first_test_name = next(iter(self._tests_state_map))
                test = self._tests_state_map[first_test_name]
                if test.state == "passed":
                    line = test.state
                else:
                    line = "{}: {}".format(test.state, test.reason)
                print(line, file=self._report_file_handle)
        finally:
            self._report_file_handle.close()
            self._report_file_handle = None

    @staticmethod
    def get_atf_vars() -> Dict[str, str]:
        """Return environment variables prefixed ``_ATF_VAR_``, prefix removed."""
        px = "_ATF_VAR_"
        return {k[len(px):]: v for k, v in os.environ.items() if k.startswith(px)}