mirror of
https://github.com/ytdl-org/youtube-dl.git
synced 2025-12-08 15:12:43 +01:00
[utils] Support traversal helper functions require, value, unpack
Thx: yt-dlp/yt-dlp#10653
This commit is contained in:
@@ -15,8 +15,11 @@ import re
|
|||||||
from youtube_dl.traversal import (
|
from youtube_dl.traversal import (
|
||||||
dict_get,
|
dict_get,
|
||||||
get_first,
|
get_first,
|
||||||
|
require,
|
||||||
T,
|
T,
|
||||||
traverse_obj,
|
traverse_obj,
|
||||||
|
unpack,
|
||||||
|
value,
|
||||||
)
|
)
|
||||||
from youtube_dl.compat import (
|
from youtube_dl.compat import (
|
||||||
compat_chr as chr,
|
compat_chr as chr,
|
||||||
@@ -27,7 +30,9 @@ from youtube_dl.compat import (
|
|||||||
compat_zip as zip,
|
compat_zip as zip,
|
||||||
)
|
)
|
||||||
from youtube_dl.utils import (
|
from youtube_dl.utils import (
|
||||||
|
ExtractorError,
|
||||||
int_or_none,
|
int_or_none,
|
||||||
|
join_nonempty,
|
||||||
str_or_none,
|
str_or_none,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -462,8 +467,8 @@ class TestTraversal(_TestCase):
|
|||||||
}),
|
}),
|
||||||
values = dict((str(k), v) for k, v in values.items())
|
values = dict((str(k), v) for k, v in values.items())
|
||||||
|
|
||||||
for key, value in values.items():
|
for key, val in values.items():
|
||||||
self.assertEqual(traverse_obj(morsel, key), value,
|
self.assertEqual(traverse_obj(morsel, key), val,
|
||||||
msg='Morsel should provide access to all values')
|
msg='Morsel should provide access to all values')
|
||||||
values = list(values.values())
|
values = list(values.values())
|
||||||
self.assertMaybeCountEqual(traverse_obj(morsel, Ellipsis), values,
|
self.assertMaybeCountEqual(traverse_obj(morsel, Ellipsis), values,
|
||||||
@@ -481,8 +486,31 @@ class TestTraversal(_TestCase):
|
|||||||
[True, 1, 1.1, 'str', {0: 0}, [1]],
|
[True, 1, 1.1, 'str', {0: 0}, [1]],
|
||||||
'`filter` should filter falsy values')
|
'`filter` should filter falsy values')
|
||||||
|
|
||||||
def test_get_first(self):
|
|
||||||
self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam')
|
class TestTraversalHelpers(_TestCase):
|
||||||
|
def test_traversal_require(self):
|
||||||
|
with self.assertRaises(ExtractorError, msg='Missing `value` should raise'):
|
||||||
|
traverse_obj(_TEST_DATA, ('None', T(require('value'))))
|
||||||
|
self.assertEqual(
|
||||||
|
traverse_obj(_TEST_DATA, ('str', T(require('value')))), 'str',
|
||||||
|
'`require` should pass through non-`None` values')
|
||||||
|
|
||||||
|
def test_unpack(self):
|
||||||
|
self.assertEqual(
|
||||||
|
unpack(lambda *x: ''.join(map(compat_str, x)))([1, 2, 3]), '123')
|
||||||
|
self.assertEqual(
|
||||||
|
unpack(join_nonempty)([1, 2, 3]), '1-2-3')
|
||||||
|
self.assertEqual(
|
||||||
|
unpack(join_nonempty, delim=' ')([1, 2, 3]), '1 2 3')
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
unpack(join_nonempty)()
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
unpack()
|
||||||
|
|
||||||
|
def test_value(self):
|
||||||
|
self.assertEqual(
|
||||||
|
traverse_obj(_TEST_DATA, ('str', T(value('other')))), 'other',
|
||||||
|
'`value` should substitute specified value')
|
||||||
|
|
||||||
|
|
||||||
class TestDictGet(_TestCase):
|
class TestDictGet(_TestCase):
|
||||||
@@ -508,6 +536,9 @@ class TestDictGet(_TestCase):
|
|||||||
self.assertEqual(dict_get(d, ('b', 'c', key, )), None)
|
self.assertEqual(dict_get(d, ('b', 'c', key, )), None)
|
||||||
self.assertEqual(dict_get(d, ('b', 'c', key, ), skip_false_values=False), false_value)
|
self.assertEqual(dict_get(d, ('b', 'c', key, ), skip_false_values=False), false_value)
|
||||||
|
|
||||||
|
def test_get_first(self):
|
||||||
|
self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -5,6 +5,9 @@
|
|||||||
from .utils import (
|
from .utils import (
|
||||||
dict_get,
|
dict_get,
|
||||||
get_first,
|
get_first,
|
||||||
|
require,
|
||||||
T,
|
T,
|
||||||
traverse_obj,
|
traverse_obj,
|
||||||
|
unpack,
|
||||||
|
value,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6543,6 +6543,31 @@ def traverse_obj(obj, *paths, **kwargs):
|
|||||||
return None if default is NO_DEFAULT else default
|
return None if default is NO_DEFAULT else default
|
||||||
|
|
||||||
|
|
||||||
|
def value(value):
|
||||||
|
return lambda _: value
|
||||||
|
|
||||||
|
|
||||||
|
class require(ExtractorError):
|
||||||
|
def __init__(self, name, expected=False):
|
||||||
|
super(require, self).__init__(
|
||||||
|
'Unable to extract {0}'.format(name), expected=expected)
|
||||||
|
|
||||||
|
def __call__(self, value):
|
||||||
|
if value is None:
|
||||||
|
raise self
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def unpack(func, **kwargs):
|
||||||
|
"""Make a function that applies `partial(func, **kwargs)` to its argument as *args"""
|
||||||
|
@functools.wraps(func)
|
||||||
|
def inner(items):
|
||||||
|
return func(*items, **kwargs)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
def T(*x):
|
def T(*x):
|
||||||
""" For use in yt-dl instead of {type, ...} or set((type, ...)) """
|
""" For use in yt-dl instead of {type, ...} or set((type, ...)) """
|
||||||
return set(x)
|
return set(x)
|
||||||
|
|||||||
Reference in New Issue
Block a user