00001 ##################################################################################
00002 # Copyright (c) 2006  Gerard Flanagan
00003 #
00004 # Permission is hereby granted, free of charge, to any person obtaining
00005 # a copy of this software and associated documentation files (the "Software"),
00006 # to deal in the Software without restriction, including without limitation
00007 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
00008 # and/or sell copies of the Software, and to permit persons to whom the
00009 # Software is furnished to do so, subject to the following conditions:
00010 #
00011 #    The above copyright notice and this permission notice shall be included
00012 #    in all copies or substantial portions of the Software.
00013 #
00014 # Requires 'elementtree' which can be obtained from:
00015 #
00016 #   http://www.effbot.org/zone/element-index.htm
00017 #
00018 ##################################################################################
00019 
00020 import re
00021 import elementtree.ElementTree as ET
00022 
# Captures the name of a trailing attribute step, e.g. 'id' for 'item/@id'.
# The lookbehind keeps the '@' itself out of the match, and excluding ']'
# stops an '@' inside a '[...]' filter from being taken for that final step.
RE_ATTRIBUTE = re.compile(r"(?<=@)[^\]]+$")

# Splits a '/'-prefixed path into (namespace, tag, filter) triples, one per
# step; every part is optional and comes back as '' when absent.
# The filter body is matched lazily up to a ']' that closes the step (i.e. is
# followed by '/' or the end of the path).  A greedy '.+' here would run to the
# LAST ']' of the whole path and merge 'a[@x==1]/b[@y==2]' into a single step,
# while the lookahead still lets a filter contain brackets of its own.
RE_XPATH = re.compile(r"/({[^}]+})?([\w|*]+)?(?:\[(.+?)\](?=/|$))?")
00026 
class InvalidFilterException(Exception):
    '''Signals a filterpath that is malformed or whose filter cannot be evaluated.'''
00028 
class MultipleResultsException(Exception):
    '''Signals that more than one result was found where at most one is allowed.'''
00030 
00031 def _find_elements_by_specification( element, specs):
00032     '''
00033     Recursively search subnodes of 'element' for those nodes which
00034     meet the specification associated with the level of recursion,
00035     eg. the second spec is associated with the grandchildren of 'element'.
00036     Returns the nodes found at the last level.
00037     '''
00038     if not specs:
00039         return element[:]
00040     else:
00041         nodes = (elem for elem in element[:] if specs[0].is_satisfied_by(elem))
00042         more_specs = specs[1:]
00043         if not more_specs:
00044             return  list(nodes)
00045         else:
00046         #search subnodes
00047             result = []
00048             for node in nodes:
00049                 result.extend( _find_elements_by_specification(node, more_specs) )
00050             return result
00051 
00052 def _remove_elements_by_specification(element, specs):
00053     '''
00054     Recursively search subnodes of 'element' for those nodes which
00055     meet the specification associated with the level of recursion,
00056     eg. the second spec is associated with the grandchildren of 'element'.
00057     Removes the nodes found at the last level.
00058     '''
00059     if not specs:
00060         return
00061     else:
00062         nodes = (elem for elem in element[:] if specs[0].is_satisfied_by(elem))
00063         more_specs = specs[1:]
00064         if not more_specs:
00065             for node in nodes:
00066                 element.remove( node )
00067         else:
00068             for node in nodes:
00069                 _remove_elements_by_specification(node, more_specs)
00070 
00071 class _NodeSpecification(object):
00072 
00073     def __init__(self, namespace=None, tag=None, filter=None):
00074         self.ns = namespace
00075         self.tag = tag or '*'
00076         self.filter = filter
00077 
00078     def is_satisfied_by(self, element):
00079         result = False
00080         if self.ns:
00081             if self.tag == '*' and element.tag.startswith(self.ns):
00082                 result = True
00083             tag = self.ns + self.tag
00084         else:
00085             tag = self.tag
00086         result = result or tag == '*' or tag == element.tag
00087         if result and self.filter:
00088             expr = self.filter
00089             for key,value in element.items():
00090                 key = '@' + key
00091                 try:
00092                     float(value)
00093                 except ValueError:
00094                     value = "\"%s\"" % value
00095                 expr = expr.replace( key, value )
00096             try:
00097                 result = eval(expr)
00098             except:
00099                 raise InvalidFilterException()
00100         return result
00101 
def _filterpath_to_node_specification( filterpath ):
    '''
    Parse 'filterpath' into a pair (specs, attr).

    'specs' is a list with one _NodeSpecification per path step and 'attr'
    is the name of a trailing '@name' attribute step, or None if the path
    does not end in one.  InvalidFilterException is raised when a trailing
    attribute is not separated from the preceding step by a '/'.

    >>> def parse_filter(path):
    ...     specs, attr =  _filterpath_to_node_specification(path)
    ...     return [(s.ns, s.tag, s.filter) for s in specs], attr
    ...
    >>> parse_filter('book[@author==Hopkins]/@barcode')
    ([('', 'book', '@author==Hopkins')], 'barcode')
    >>> parse_filter('books/*/title/author')
    ([('', 'books', ''), ('', '*', ''), ('', 'title', ''), ('', 'author', '')], None)
    >>> parse_filter('{ns}book[@price<20]/{ns}title')
    ([('{ns}', 'book', '@price<20'), ('{ns}', 'title', '')], None)
    >>> parse_filter('@id')
    ([], 'id')
    >>> parse_filter('@{ns}id')
    ([], '{ns}id')
    >>> parse_filter('{ns}@id')
    Traceback (most recent call last):
        ...
    InvalidFilterException
    >>> parse_filter('{ns}*')
    ([('{ns}', '*', '')], None)
    >>> parse_filter('{ns}*') == parse_filter('{ns}')
    True
    >>> parse_filter('{ns}/@id')
    ([('{ns}', '*', '')], 'id')
    >>> parse_filter('{ns}*[@name.startswith("A")]/@id')
    ([('{ns}', '*', '@name.startswith("A")')], 'id')
    '''
    attr = None
    #RE_ATTRIBUTE yields the bare name of a trailing attribute (id, not @id)
    trailing = RE_ATTRIBUTE.search(filterpath)
    if trailing:
        attr = trailing.group(0)
        #cut off the attribute name together with the '@' in front of it
        filterpath = RE_ATTRIBUTE.sub('', filterpath)[:-1]
        if filterpath and not filterpath.endswith('/'):
            raise InvalidFilterException()
    if filterpath:
        #RE_XPATH.findall() only recognises steps introduced by a forward slash
        if not filterpath.startswith('/'):
            filterpath = '/' + filterpath
        #a trailing forward slash carries no information
        if filterpath.endswith('/'):
            filterpath = filterpath[:-1]
    steps = RE_XPATH.findall(filterpath)
    return [_NodeSpecification(*step) for step in steps], attr
00154 
def findall( element, filterpath ):
    '''
    Return a list of everything below 'element' selected by 'filterpath':
    the matching elements, or attribute values when the filterpath ends in
    an '@name' step.

    >>> XML1, XML2 = _testdata()
    >>> [elem.text for elem in findall(XML1, "channel/item/title")]
    ['Normalizing XML, Part 2', 'The .NET Schema Object Model', "SVG's Past and Promising Future"]
    >>> [elem.text for elem in findall(XML1, "channel/item/creator")]
    []
    >>> #NB. James Clark format for namespace-qualified elements
    >>> [elem.text for elem in findall(XML1, "channel/item/{http://purl.org/dc/elements/1.1/}creator")]
    ['Will Provost', 'Priya Lakshminarayanan', 'Antoine Quint']
    >>> findall( XML2, 'category/item[@id=="A001"]/@colour' )
    ['red']
    >>> findall( XML2, 'category/item[@id.startswith("A")]/@colour' )
    ['red', 'blue', 'yellow']
    >>> findall( XML2, 'category[@id<500]/@id' )
    ['123', '456']
    '''
    selection = ElementFilter( element, filterpath )
    return selection.findall()
00173 
def removeall( element, filterpath ):
    '''
    Remove from the tree below 'element' whatever 'filterpath' selects:
    the matching elements, or just the named attribute when the filterpath
    ends in an '@name' step.

    >>> XML1, XML2 = _testdata()
    >>> filterpath = "channel/item/title"
    >>> [elem.text for elem in findall(XML1, filterpath)]
    ['Normalizing XML, Part 2', 'The .NET Schema Object Model', "SVG's Past and Promising Future"]
    >>> removeall(XML1, filterpath)
    >>> [elem.text for elem in findall(XML1, filterpath)]
    []
    >>> filter = ElementFilter( XML2, "category/item/@colour")
    >>> filter.findall()
    ['red', 'blue', 'yellow', 'pink', 'blue', 'green', 'pink', 'orange', 'blue']
    >>> filter.removeall()
    >>> filter.findall()
    [None, None, None, None, None, None, None, None, None]
    '''
    selection = ElementFilter( element, filterpath )
    selection.removeall()
00191 
def count( element, filterpath ):
    '''
    Return how many results 'filterpath' selects below 'element'.

    >>> XML1, XML2 = _testdata()
    >>> count(XML1, 'channel/item')
    3
    >>> count(XML2, 'category/item[@colour=="blue"]')
    3
    '''
    selection = ElementFilter( element, filterpath )
    return selection.count()
00201 
def data( element, filterpath ):
    '''
    Return a list with the text of every element selected by 'filterpath',
    or the attribute values when the filterpath ends in an '@name' step.

    >>> XML1, XML2 = _testdata()
    >>> data(XML1, "channel/item/title")
    ['Normalizing XML, Part 2', 'The .NET Schema Object Model', "SVG's Past and Promising Future"]
    >>> data(XML2, "category/item/@colour")
    ['red', 'blue', 'yellow', 'pink', 'blue', 'green', 'pink', 'orange', 'blue']
    '''
    selection = ElementFilter( element, filterpath )
    return selection.data()
00211 
def doc( element, filterpath, tag="root" ):
    '''
    Build an ElementTree instance around the elements selected by
    'filterpath': its document root is a new 'tag' element and the
    selected elements are that root's subnodes.

    >>> XML1, XML2 = _testdata()
    >>> filterpath = 'category/item[@colour=="blue"]'
    >>> root = doc(XML2, filterpath).getroot()
    >>> len(root[:])
    3
    '''
    selection = ElementFilter(element, filterpath)
    return selection.doc(tag)
00224 
def replace( element, filterpath, old, new, count=-1):
    '''
    Substitute 'new' for 'old' inside the text of the elements selected by
    'filterpath', or inside the attribute values when the filterpath ends
    in an '@name' step.  'count' is handed on to the string replace() call.

    >>> XML1, XML2 = _testdata()
    >>> filter = ElementFilter(XML1, "channel/title")
    >>> filter.data()
    ['XML.com']
    >>> filter.replace('XML', 'xml')
    >>> filter.data()
    ['xml.com']
    >>> filter = ElementFilter(XML2, "category/item/@colour")
    >>> filter.data()
    ['red', 'blue', 'yellow', 'pink', 'blue', 'green', 'pink', 'orange', 'blue']
    >>> filter.replace('pink', 'lilac')
    >>> filter.data()
    ['red', 'blue', 'yellow', 'lilac', 'blue', 'green', 'lilac', 'orange', 'blue']
    '''
    selection = ElementFilter(element, filterpath)
    selection.replace(old, new, count)
00244 
def text(element, filterpath):
    '''
    Return the text (or attribute value) of the single result selected by
    'filterpath', or None when nothing is selected.  More than one result
    raises MultipleResultsException.

    >>> XML1, XML2 = _testdata()
    >>> text(XML1, 'channel/title')
    'XML.com'
    >>> text(XML1, 'channel/item')
    Traceback (most recent call last):
       ...
    MultipleResultsException
    '''
    selection = ElementFilter(element, filterpath)
    return selection.text()
00256 
class ElementFilter(object):
    '''
    Applies a filterpath to a root element and offers queries and edits on
    whatever it selects.

    element   -- the element whose subnodes are searched
    filter    -- the filterpath; assigning it re-parses the path and drops
                 all cached results
    specs     -- node specifications parsed from the filterpath
    attribute -- attribute name when the filterpath ends in '@name',
                 otherwise None
    filtered  -- list of selected elements, computed on first use and
                 cached until the filter changes or elements are removed
    '''

    #Value findall() reports for an element lacking the requested attribute;
    #data() and count() leave out entries that are this very object.
    attribute_default = None

    def __get_filter(self):
        return self.__filter

    def __set_filter(self, filter):
        self.__filter = filter
        self.specs, self.attribute  = _filterpath_to_node_specification(filter)
        self._invalidate()

    filter = property( __get_filter,__set_filter )

    def __get_filtered(self):
        if self._cache is None:
            self._cache = _find_elements_by_specification( self.element, self.specs )
        return self._cache

    filtered = property( __get_filtered )

    def __init__(self, element, filterpath=''):
        self.element = element
        self.__set_filter( filterpath )

    def _invalidate(self):
        '''Forget every cached result so that it is recomputed on next use.'''
        self._cache = None
        self._doc = None
        self._count = None

    def findall(self):
        '''
        Return the selected elements or, for an attribute filterpath, one
        value per selected element (attribute_default where it is missing).
        Without an attribute the cached list itself is returned.
        '''
        if self.attribute is None:
            return self.filtered
        else:
            #note that the presence of a default value for a non-present key means that
            #for attributes: self.count() = len(self.data()) <= len(self.findall())
            return  [ e.get(self.attribute, self.attribute_default) for e in self.filtered ]

    def removeall(self):
        '''
        Remove the selected elements from the tree or, for an attribute
        filterpath, delete that attribute from every selected element.
        '''
        if self.attribute is None:
            _remove_elements_by_specification( self.element, self.specs )
            #the cached elements are no longer part of the tree
            self._invalidate()
        else:
            for elem in self.filtered:
                if self.attribute in elem.attrib:
                    del elem.attrib[self.attribute]
            #same elements as before, but the attribute tally has changed
            self._count = None

    def count(self):
        '''
        Return the number of selected elements or, for an attribute
        filterpath, the number of attribute values actually present.
        '''
        if self._count is None:
            if self.attribute is None:
                self._count = len(self.filtered)
            else:
                self._count = len(self.data())
        return self._count

    def data(self):
        '''
        Return the text of every selected element or, for an attribute
        filterpath, the attribute values without the attribute_default
        placeholders.
        '''
        if self.attribute is None:
            return [ elem.text for elem in self.filtered ]
        else:
            return [ attr for attr in self.findall() if attr is not self.attribute_default ]

    def distinct_values(self):
        '''Return the set of values found in data().'''
        return set(self.data())

    def doc(self, tag="root"):
        '''
        Return an ElementTree whose root is a new 'tag' element holding the
        selected elements.  The tree is cached and only rebuilt when a
        different 'tag' is requested or the cache has been dropped.
        '''
        if self._doc is None or self._doc.getroot().tag != tag:
            root = ET.Element(tag)
            for elem in self.filtered:
                root.append(elem)
            self._doc = ET.ElementTree(root)
        return self._doc

    def empty(self):
        '''Return True when the filterpath selects nothing.'''
        return not self.count()

    def replace(self, old, new, count=-1):
        '''
        Replace 'old' by 'new' in the text and tail of every selected
        element or, for an attribute filterpath, in the attribute value.
        'count' limits the replacements per string as in str.replace().
        '''
        for elem in self.filtered:
            if self.attribute is None:
                #text and tail are None on elements that carry no text
                if elem.text is not None:
                    elem.text = elem.text.replace(old, new, count)
                if elem.tail is not None:
                    elem.tail = elem.tail.replace(old, new, count)
            else:
                oldval = elem.get(self.attribute)
                if oldval:
                    elem.set(self.attribute, oldval.replace(old, new, count))

    def zero_or_one(self):
        '''
        Return findall(), raising MultipleResultsException if it holds
        more than one result.
        '''
        result = self.findall()
        if len(result) > 1:
            raise MultipleResultsException()
        return result

    def text(self):
        '''
        Return the text of the only selected element (or the only attribute
        value), or None when nothing is selected.  Raises
        MultipleResultsException for more than one result.
        '''
        result = self.zero_or_one()
        if not result:
            return None
        if self.attribute is None:
            return result[0].text
        else:
            return result[0]
00352 
def _test():
    '''Run the doctests embedded in this module's docstrings.'''
    from doctest import testmod
    testmod()
00356 
def _testdata():
    '''Parse the two XML fixtures afresh and return their roots as (RSS_TEST root, ATTR_TEST root).'''
    rss_root = ET.fromstring(RSS_TEST)
    attr_root = ET.fromstring(ATTR_TEST)
    return rss_root, attr_root
00359 
# Doctest fixture: an RSS 2.0 document with one channel holding three items.
# The creator and date children of each item are in the 'dc' namespace
# (http://purl.org/dc/elements/1.1/); all other elements are unqualified.
RSS_TEST = ''' <rss version="2.0" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel encoding="utf-8">
    <title>XML.com</title>
    <link>http://www.xml.com/</link>
    <description>XML.com features a rich mix of information and services for the XML community.</description>
    <language>en-us</language>
    <item>
      <title>Normalizing XML, Part 2</title>
      <link>http://www.xml.com/pub/a/2002/12/04/normalizing.html</link>
      <description>In this second and final look at applying relational normalization techniques to W3C XML Schema data modeling, Will Provost discusses when not to normalize, the scope of uniqueness and the fourth and fifth normal forms.</description>
      <dc:creator>Will Provost</dc:creator>
      <dc:date>2002-12-04</dc:date>
    </item>
    <item>
      <title>The .NET Schema Object Model</title>
      <link>http://www.xml.com/pub/a/2002/12/04/som.html</link>
      <description>Priya Lakshminarayanan describes in detail the use of the .NET Schema Object Model for programmatic manipulation of W3C XML Schemas.</description>
      <dc:creator>Priya Lakshminarayanan</dc:creator>
      <dc:date>2002-12-04</dc:date>
    </item>
    <item>
      <title>SVG's Past and Promising Future</title>
      <link>http://www.xml.com/pub/a/2002/12/04/svg.html</link>
      <description>In this month's SVG column, Antoine Quint looks back at SVG's journey through 2002 and looks forward to 2003.</description>
      <dc:creator>Antoine Quint</dc:creator>
      <dc:date>2002-12-04</dc:date>
    </item>
  </channel>
</rss>'''
00389 
# Doctest fixture for attribute filters: a title plus three categories
# (ids 123, 456, 789), each with three items carrying 'id' and 'colour'.
ATTR_TEST = '''<root>
    <title lang="en" encoding="utf-8">Document Title</title>
    <category id="123" code="A">
        <item id="A001" colour="red">Item A1</item>
        <item id="A002" colour="blue">Item A2</item>
        <item id="A003" colour="yellow">Item A3</item>
    </category>
    <category id="456" code="B">
        <item id="B001" colour="pink">Item B1</item>
        <item id="B002" colour="blue">Item B2</item>
        <item id="B003" colour="green">Item B3</item>
    </category>
    <category id="789" code="C">
        <item id="C001" colour="pink">Item C1</item>
        <item id="C002" colour="orange">Item C2</item>
        <item id="C003" colour="blue">Item C3</item>
    </category>
</root>'''
00408 
# Executing this file as a script runs its doctests.
if __name__ == '__main__':
    _test()