diff --git a/src/css_selectors/parse.py b/src/css_selectors/parse.py index ce4ee8e348..603c18aa51 100644 --- a/src/css_selectors/parse.py +++ b/src/css_selectors/parse.py @@ -14,7 +14,7 @@ import operator import string -from css_selectors.errors import SelectorSyntaxError +from css_selectors.errors import SelectorSyntaxError, ExpressionError if sys.version_info[0] < 3: _unicode = unicode @@ -159,6 +159,7 @@ def __init__(self, selector, name, arguments): self.selector = selector self.name = ascii_lower(name) self.arguments = arguments + self._parsed_arguments = None def __repr__(self): return '%s[%r:%s(%s)]' % ( @@ -168,6 +169,19 @@ def __repr__(self): def argument_types(self): return [token.type for token in self.arguments] + @property + def parsed_arguments(self): + if self._parsed_arguments is None: + try: + self._parsed_arguments = parse_series(self.arguments) + except ValueError: + raise ExpressionError("Invalid series: '%r'" % self.arguments) + return self._parsed_arguments + + def parse_arguments(self): + if not self.arguments_parsed: + self.arguments_parsed = True + def specificity(self): a, b, c = self.selector.specificity() b += 1 diff --git a/src/css_selectors/select.py b/src/css_selectors/select.py index 3da128fb67..2f60b5cef4 100644 --- a/src/css_selectors/select.py +++ b/src/css_selectors/select.py @@ -9,6 +9,7 @@ import re, itertools from collections import OrderedDict, defaultdict from functools import wraps +from itertools import chain from lxml import etree @@ -90,13 +91,16 @@ class Select(object): Tags are returned in document order. Note that attribute and tag names are matched case-insensitively. Also namespaces are ignored (this is for - performance of the common case). + performance of the common case). The UI related selectors are not + implemented, such as :enabled, :diabled, :checked, :hover, etc. Similarly, + the non-element related selectors such as ::first-line, ::first-letter, + ::before, etc. are not implemented. WARNING: This class uses internal caches. You *must not* make any changes to the lxml tree. If you do make some changes, either create a new Select object or call :meth:`invalidate_caches`. - This class can be easily sub-classes to work with tree implementations + This class can be easily sub-classed to work with tree implementations other than lxml. Simply override the methods in the ``Tree Integration`` block. @@ -135,6 +139,11 @@ def invalidate_caches(self): self._attrib_map = None self._attrib_space_map = None self._lang_map = None + self.map_tag_name = ascii_lower + if '{' in self.root.tag: + def map_tag_name(x): + return ascii_lower(x.rpartition('}')[2]) + self.map_tag_name = map_tag_name def __call__(self, selector): 'Return an iterator over all matching tags, in document order.' @@ -159,13 +168,8 @@ def iterparsedselector(self, parsed_selector): def element_map(self): if self._element_map is None: self._element_map = em = defaultdict(OrderedSet) - map_tag_name = ascii_lower - if '{' in self.root.tag: - def map_tag_name(x): - return ascii_lower(x.rpartition('}')[2]) - for tag in self.itertag(): - em[map_tag_name(tag.tag)].add(tag) + em[self.map_tag_name(tag.tag)].add(tag) return self._element_map @property @@ -251,6 +255,38 @@ def iteridtags(self): def iterclasstags(self): return get_compiled_xpath('//*[@class]')(self.root) + + def sibling_count(self, child, before=True, same_type=False): + ' Return the number of siblings before or after child or raise ValueError if child has no parent. ' + parent = child.getparent() + if parent is None: + raise ValueError('Child has no parent') + if same_type: + siblings = OrderedSet(child.itersiblings(preceding=before)) + return len(self.element_map[self.map_tag_name(child.tag)] & siblings) + else: + if before: + return parent.index(child) + return len(parent) - parent.index(child) - 1 + + def all_sibling_count(self, child, same_type=False): + ' Return the number of siblings of child or raise ValueError if child has no parent ' + parent = child.getparent() + if parent is None: + raise ValueError('Child has no parent') + if same_type: + siblings = OrderedSet(chain(child.itersiblings(preceding=False), child.itersiblings(preceding=True))) + return len(self.element_map[self.map_tag_name(child.tag)] & siblings) + else: + return len(parent) - 1 + + def is_empty(self, elem): + for child in elem: + # Check for comment/PI nodes with tail text + if child.tail: + return False + return len(tuple(elem.iterchildren('*'))) == 0 and not elem.text + # }}} # Combinators {{{ @@ -324,6 +360,13 @@ def select_class(cache, selector): if elem in items: yield elem +def select_negation(cache, selector): + 'Implement :not()' + exclude = frozenset(cache.iterparsedselector(selector.subselector)) + for item in cache.iterparsedselector(selector.selector): + if item not in exclude: + yield item + # Attribute selectors {{{ def select_attrib(cache, selector): @@ -381,17 +424,24 @@ def select_substringmatch(cache, attrib, value): def select_function(cache, function): """Select with a functional pseudo-class.""" + fname = function.name.replace('-', '_') try: - func = cache.dispatch_map[function.name.replace('-', '_')] + func = cache.dispatch_map[fname] except KeyError: raise ExpressionError( "The pseudo-class :%s() is unknown" % function.name) - items = frozenset(func(cache, function)) - for item in cache.iterparsedselector(function.selector): - if item in items: - yield item + if fname == 'lang': + items = frozenset(func(cache, function)) + for item in cache.iterparsedselector(function.selector): + if item in items: + yield item + else: + for item in cache.iterparsedselector(function.selector): + if func(cache, function, item): + yield item def select_lang(cache, function): + ' Implement :lang() ' if function.argument_types() not in (['STRING'], ['IDENT']): raise ExpressionError("Expected a single string or ident for :lang(), got %r" % function.arguments) lang = function.arguments[0].value @@ -403,12 +453,118 @@ def select_lang(cache, function): for elem in elem_set: yield elem +def select_nth_child(cache, function, elem): + ' Implement :nth-child() ' + a, b = function.parsed_arguments + try: + num = cache.sibling_count(elem) + 1 + except ValueError: + return False + if a == 0: + return num == b + n = (num - b) / a + return n.is_integer() and n > -1 + +def select_nth_last_child(cache, function, elem): + ' Implement :nth-last-child() ' + a, b = function.parsed_arguments + try: + num = cache.sibling_count(elem, before=False) + 1 + except ValueError: + return False + if a == 0: + return num == b + n = (num - b) / a + return n.is_integer() and n > -1 + +def select_nth_of_type(cache, function, elem): + ' Implement :nth-of-type() ' + a, b = function.parsed_arguments + try: + num = cache.sibling_count(elem, same_type=True) + 1 + except ValueError: + return False + if a == 0: + return num == b + n = (num - b) / a + return n.is_integer() and n > -1 + +def select_nth_last_of_type(cache, function, elem): + ' Implement :nth-last-of-type() ' + a, b = function.parsed_arguments + try: + num = cache.sibling_count(elem, before=False, same_type=True) + 1 + except ValueError: + return False + if a == 0: + return num == b + n = (num - b) / a + return n.is_integer() and n > -1 + +# }}} + +# Pseudo elements {{{ + +def select_pseudo(cache, pseudo): + if pseudo.ident == 'root': + yield cache.root + return + + try: + func = cache.dispatch_map[pseudo.ident.replace('-', '_')] + except KeyError: + raise ExpressionError( + "The pseudo-class :%s is not supported" % pseudo.ident) + + for item in cache.iterparsedselector(pseudo.selector): + if func(cache, item): + yield item + +def select_first_child(cache, elem): + try: + return cache.sibling_count(elem) == 0 + except ValueError: + return False + +def select_last_child(cache, elem): + try: + return cache.sibling_count(elem, before=False) == 0 + except ValueError: + return False + +def select_only_child(cache, elem): + try: + return cache.all_sibling_count(elem) == 0 + except ValueError: + return False + +def select_first_of_type(cache, elem): + try: + return cache.sibling_count(elem, same_type=True) == 0 + except ValueError: + return False + +def select_last_of_type(cache, elem): + try: + return cache.sibling_count(elem, before=False, same_type=True) == 0 + except ValueError: + return False + +def select_only_of_type(cache, elem): + try: + return cache.all_sibling_count(elem, same_type=True) == 0 + except ValueError: + return False + +def select_empty(cache, elem): + return cache.is_empty(elem) + # }}} default_dispatch_map = {name.partition('_')[2]:obj for name, obj in globals().items() if name.startswith('select_') and callable(obj)} if __name__ == '__main__': from pprint import pprint - root = etree.fromstring('
') + root = etree.fromstring('') select = Select(root, trace=True) - pprint(list(select('p a'))) + pprint(list(select('p *:root'))) diff --git a/src/css_selectors/tests.py b/src/css_selectors/tests.py index ad2d9ea05c..0a880d790f 100644 --- a/src/css_selectors/tests.py +++ b/src/css_selectors/tests.py @@ -6,14 +6,26 @@ __license__ = 'GPL v3' __copyright__ = '2015, Kovid Goyal