[utils] Support traversal helper functions require, value, unpack

Thx: yt-dlp/yt-dlp#10653
This commit is contained in:
dirkf
2025-10-31 13:36:55 +00:00
parent 96419fa706
commit 68fe8c1781
3 changed files with 63 additions and 4 deletions

View File

@@ -15,8 +15,11 @@ import re
from youtube_dl.traversal import (
dict_get,
get_first,
require,
T,
traverse_obj,
unpack,
value,
)
from youtube_dl.compat import (
compat_chr as chr,
@@ -27,7 +30,9 @@ from youtube_dl.compat import (
compat_zip as zip,
)
from youtube_dl.utils import (
ExtractorError,
int_or_none,
join_nonempty,
str_or_none,
)
@@ -462,8 +467,8 @@ class TestTraversal(_TestCase):
}),
values = dict((str(k), v) for k, v in values.items())
for key, value in values.items():
self.assertEqual(traverse_obj(morsel, key), value,
for key, val in values.items():
self.assertEqual(traverse_obj(morsel, key), val,
msg='Morsel should provide access to all values')
values = list(values.values())
self.assertMaybeCountEqual(traverse_obj(morsel, Ellipsis), values,
@@ -481,8 +486,31 @@ class TestTraversal(_TestCase):
[True, 1, 1.1, 'str', {0: 0}, [1]],
'`filter` should filter falsy values')
def test_get_first(self):
self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam')
class TestTraversalHelpers(_TestCase):
def test_traversal_require(self):
with self.assertRaises(ExtractorError, msg='Missing `value` should raise'):
traverse_obj(_TEST_DATA, ('None', T(require('value'))))
self.assertEqual(
traverse_obj(_TEST_DATA, ('str', T(require('value')))), 'str',
'`require` should pass through non-`None` values')
def test_unpack(self):
self.assertEqual(
unpack(lambda *x: ''.join(map(compat_str, x)))([1, 2, 3]), '123')
self.assertEqual(
unpack(join_nonempty)([1, 2, 3]), '1-2-3')
self.assertEqual(
unpack(join_nonempty, delim=' ')([1, 2, 3]), '1 2 3')
with self.assertRaises(TypeError):
unpack(join_nonempty)()
with self.assertRaises(TypeError):
unpack()
def test_value(self):
self.assertEqual(
traverse_obj(_TEST_DATA, ('str', T(value('other')))), 'other',
'`value` should substitute specified value')
class TestDictGet(_TestCase):
@@ -508,6 +536,9 @@ class TestDictGet(_TestCase):
self.assertEqual(dict_get(d, ('b', 'c', key, )), None)
self.assertEqual(dict_get(d, ('b', 'c', key, ), skip_false_values=False), false_value)
def test_get_first(self):
self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam')
if __name__ == '__main__':
unittest.main()