1'''
2This module contains the classes which represent XCB data types.
3'''
4import sys
5from xcbgen.expr import Field, Expression
6from xcbgen.align import Alignment, AlignmentLog
7
8if sys.version_info[:2] >= (3, 3):
9    from xml.etree.ElementTree import SubElement
10else:
11    from xml.etree.cElementTree import SubElement
12
13import __main__
14
# When True, the alignment checker prints verbose diagnostics.
verbose_align_log = False
# Lower-cased XML attribute values that are interpreted as boolean true
# (e.g. the "serialize" attribute of a pad element).
true_values = 'true 1 yes'.split()
17
class Type(object):
    '''
    Abstract base class for all XCB data types.
    Contains default fields, and some abstract methods.
    '''
    def __init__(self, name):
        '''
        Default structure initializer.  Sets up default fields.

        Public fields:
        name is a tuple of strings specifying the full type name
            (a plain string is also tolerated by type_name_to_str).
        size is the size of the datatype in bytes, or None if variable-sized.
        nmemb is 1 for non-list types, None for variable-sized lists, otherwise number of elts.
        resolved becomes True once resolve() has run.
        is_* booleans identify the concrete subclass; each subclass sets its own flag.
        required_start_align is the Alignment required at the start of this type;
            subclasses may replace it, or set it to None to request autocompute.
        max_align_pad is the biggest align value of an align-pad contained in this type.
        '''
        self.name = name
        self.size = None
        self.nmemb = None
        self.resolved = False

        # Subclass identification flags (used instead of isinstance()).
        self.is_simple = False
        self.is_list = False
        self.is_expr = False
        self.is_container = False
        self.is_reply = False
        self.is_union = False
        self.is_pad = False
        self.is_eventstruct = False
        self.is_event = False
        self.is_switch = False
        self.is_case_or_bitcase = False
        self.is_bitcase = False
        self.is_case = False
        self.is_fd = False
        self.required_start_align = Alignment()

        # the biggest align value of an align-pad contained in this type
        self.max_align_pad = 1

    def resolve(self, module):
        '''
        Abstract method for resolving a type.
        This should make sure any referenced types are already declared.
        '''
        raise Exception('abstract resolve method not overridden!')

    def out(self, name):
        '''
        Abstract method for outputting code.
        These are declared in the language-specific modules, and
        there must be a dictionary containing them declared when this module is imported!
        '''
        raise Exception('abstract out method not overridden!')

    def fixed_size(self):
        '''
        Abstract method for determining if the data type is fixed-size.
        '''
        raise Exception('abstract fixed_size method not overridden!')

    def make_member_of(self, module, complex_type, field_type, field_name, visible, wire, auto, enum=None, is_fd=False):
        '''
        Default method for making a data type a member of a structure.
        Extend this if the data type needs to add an additional length field or something.

        module is the global module object.
        complex_type is the structure object.
        see Field for the meaning of the other parameters.
        '''
        new_field = Field(self, field_type, field_name, visible, wire, auto, enum, is_fd)
        Type._attach_field(complex_type, new_field)

    def make_fd_of(self, module, complex_type, fd_name):
        '''
        Method for making a fd member of a structure.
        The field is typed as INT32, visible, not on the wire, and flagged as an fd.
        '''
        new_fd = Field(self, module.get_type_name('INT32'), fd_name, True, False, False, None, True)
        Type._attach_field(complex_type, new_fd)

    @staticmethod
    def _attach_field(complex_type, new_field):
        '''
        Add new_field to complex_type.fields and record complex_type as its parent.

        If the structure still holds the _placeholder_byte, the first
        placeholder is replaced in place instead of appending.  The parent is
        set on both paths; previously a field that replaced the placeholder,
        and every fd field, was left without a parent.
        '''
        new_field.parent = complex_type

        # We dump the _placeholder_byte if any fields are added.
        for (idx, field) in enumerate(complex_type.fields):
            if field == _placeholder_byte:
                complex_type.fields[idx] = new_field
                return

        complex_type.fields.append(new_field)

    def get_total_size(self):
        '''
        get the total size of this type if it is fixed-size, otherwise None
        (size * nmemb, or just size when nmemb is None)
        '''
        if not self.fixed_size():
            return None
        if self.nmemb is None:
            return self.size
        return self.size * self.nmemb

    def get_align_offset(self):
        '''
        Return the offset of required_start_align, or 0 when it is unknown (None).
        '''
        if self.required_start_align is None:
            return 0
        return self.required_start_align.offset

    def is_acceptable_start_align(self, start_align, callstack, log):
        '''
        True if start_align is compatible with this type (see get_alignment_after).
        '''
        return self.get_alignment_after(start_align, callstack, log) is not None

    def get_alignment_after(self, start_align, callstack, log):
        '''
        get the alignment after this type based on the given start_align.
        the start_align is checked for compatibility with the
        internal start align. If it is not compatible, then None is returned
        '''
        if self.required_start_align is None or self.required_start_align.is_guaranteed_at(start_align):
            return self.unchecked_get_alignment_after(start_align, callstack, log)
        else:
            if log is not None:
                log.fail(start_align, "", self, callstack + [self],
                    "start_align is incompatible with required_start_align %s"
                    % (str(self.required_start_align)))
            return None

    def unchecked_get_alignment_after(self, start_align, callstack, log):
        '''
        Abstract method for geting the alignment after this type
        when the alignment at the start is given, and when this type
        has variable size.
        '''
        raise Exception('abstract unchecked_get_alignment_after method not overridden!')

    @staticmethod
    def type_name_to_str(type_name):
        '''
        Render a type name as a string: strings pass through unchanged,
        tuples are joined with ".".
        '''
        if isinstance(type_name, str):
            #already a string
            return type_name
        else:
            return ".".join(type_name)

    def __str__(self):
        return type(self).__name__ + " \"" + Type.type_name_to_str(self.name) + "\""
169
class PrimitiveType(Type):
    '''
    Base class for fixed-size scalar types whose required start alignment
    is derived from their byte size.
    '''

    def __init__(self, name, size):
        Type.__init__(self, name)
        self.size, self.nmemb = size, 1

        # a primitive's start alignment follows directly from its size
        self.required_start_align = Alignment.for_primitive_type(size)

    def unchecked_get_alignment_after(self, start_align, callstack, log):
        frames = callstack + [self]
        result = start_align.align_after_fixed_size(self.size)

        if log is None:
            return result
        if result is None:
            log.fail(start_align, "", self, frames,
                     "align after fixed size %d failed" % self.size)
        else:
            log.ok(start_align, "", self, frames, result)
        return result

    def fixed_size(self):
        # primitives always occupy exactly self.size bytes
        return True
195
class SimpleType(PrimitiveType):
    '''
    A cardinal type such as CARD32 or char.  Every type that is typedef'ed
    to a cardinal is represented by one of these.

    Public fields added:
    xml_type is the original string used for the type in the XML
    '''
    def __init__(self, name, size, xml_type=None):
        PrimitiveType.__init__(self, name, size)
        self.xml_type = xml_type
        self.is_simple = True

    def resolve(self, module):
        # a simple type references no other types, so there is nothing to look up
        self.resolved = True

    out = __main__.output['simple']
213
214
# Cardinal datatype globals.  See module __init__ method.
# Each is SimpleType(C type name tuple, size in bytes, xml_type=XML name).

# unsigned integers
tcard8 = SimpleType(('uint8_t',), 1, xml_type='CARD8')
tcard16 = SimpleType(('uint16_t',), 2, xml_type='CARD16')
tcard32 = SimpleType(('uint32_t',), 4, xml_type='CARD32')
tcard64 = SimpleType(('uint64_t',), 8, xml_type='CARD64')
# signed integers
tint8 = SimpleType(('int8_t',), 1, xml_type='INT8')
tint16 = SimpleType(('int16_t',), 2, xml_type='INT16')
tint32 = SimpleType(('int32_t',), 4, xml_type='INT32')
tint64 = SimpleType(('int64_t',), 8, xml_type='INT64')
# character and floating point
tchar = SimpleType(('char',), 1, xml_type='char')
tfloat = SimpleType(('float',), 4, xml_type='float')
tdouble = SimpleType(('double',), 8, xml_type='double')
# XML names that share the uint8_t representation
tbyte = SimpleType(('uint8_t',), 1, xml_type='BYTE')
tbool = SimpleType(('uint8_t',), 1, xml_type='BOOL')
tvoid = SimpleType(('uint8_t',), 1, xml_type='void')
230
class FileDescriptor(SimpleType):
    '''
    Derived class which represents a file descriptor.

    It is a 4-byte SimpleType with C name "int" and XML type "fd".  is_fd is
    set so that containers can special-case it (ListType, for instance,
    keeps lists of fds off the wire).
    '''
    def __init__(self):
        # The name must be a one-element tuple like every other SimpleType
        # name; a bare ('int') is just the string 'int'.
        SimpleType.__init__(self, ('int',), 4, 'fd')
        self.is_fd = True

    def fixed_size(self):
        return True

    out = __main__.output['simple']
243
class Enum(SimpleType):
    '''
    Derived class which represents an enum.  Fixed-size.

    Public fields added:
    values contains a list of (name, value) tuples.  value is '' when the
        item has no child element, otherwise a string: the <value> text as
        written, or for a <bit> item the decimal string of 1 << bit.
    bits contains a list of (name, bitnum) tuples.  items only appear if
        specified as a bit; bitnum is the <bit> text exactly as written in the XML.
    doc is the Doc built from the enum's <doc> child, or None.
    '''
    def __init__(self, name, elt):
        SimpleType.__init__(self, name, 4, 'enum')
        self.values = []
        self.bits = []
        self.doc = None
        for item in list(elt):
            if item.tag == 'doc':
                self.doc = Doc(name, item)
                # A <doc> child is documentation, not an enum item.  Without
                # this, an empty <doc/> was recorded as a (None, '') value.
                continue

            item_name = item.get('name')
            children = list(item)

            # First check if we're using a default value
            if not children:
                self.values.append((item_name, ''))
                continue

            # An explicit value or bit was specified.
            value = children[0]
            if value.tag == 'value':
                self.values.append((item_name, value.text))
            elif value.tag == 'bit':
                # base 0 accepts decimal or hex bit numbers
                self.values.append((item_name, '%u' % (1 << int(value.text, 0))))
                self.bits.append((item_name, value.text))

    def resolve(self, module):
        self.resolved = True

    def fixed_size(self):
        return True

    out = __main__.output['enum']
281
282
class ListType(Type):
    '''
    Derived class which represents a list of some other datatype.  Fixed- or variable-sized.

    Public fields added:
    member is the datatype of the list elements.
    parents is the list of structure types searched for the list's length field.
    expr is an Expression object containing the length information, for variable-sized lists.
    '''
    def __init__(self, elt, member, *parent):
        Type.__init__(self, member.name)
        self.is_list = True
        self.member = member
        self.parents = list(parent)

        if elt.tag == 'list':
            elts = list(elt)
            is_list_in_parent = self.parents[0].elt.tag in ('request', 'event', 'reply', 'error')
            if not elts and is_list_in_parent:
                # No length expression on a list directly inside a
                # request/event/reply/error: build the expression from the
                # <list> element itself and mark it 'calculate_len'
                # (presumably the length is derived by the output backend).
                self.expr = Expression(elt, self)
                self.expr.op = 'calculate_len'
            else:
                self.expr = Expression(elts[0] if elts else elt, self)

        self.size = member.size if member.fixed_size() else None
        self.nmemb = self.expr.nmemb if self.expr.fixed_size() else None

        self.required_start_align = self.member.required_start_align

    def make_member_of(self, module, complex_type, field_type, field_name, visible, wire, auto, enum=None, is_fd=False):
        '''
        Add this list to complex_type, first adding its length field when the
        list is variable-sized and no parent already declares that field.

        is_fd is accepted for signature compatibility with Type.make_member_of;
        the fd flag actually used is self.member.is_fd.
        '''
        if not self.fixed_size():
            # We need a length field.
            # Ask our Expression object for its name, type, and whether it's on the wire.
            lenfid = self.expr.lenfield_type
            lenfield_name = self.expr.lenfield_name
            lenwire = self.expr.lenwire

            # See if the length field is already in the structure.
            needlen = not any(field.field_name == lenfield_name
                              for parent in self.parents
                              for field in parent.fields)

            # It isn't, so we need to add it to the structure ourself.
            if needlen:
                len_type = module.get_type(lenfid)
                lenfield_type = module.get_type_name(lenfid)
                len_type.make_member_of(module, complex_type, lenfield_type, lenfield_name, True, lenwire, False, enum)

        # Add ourself to the structure by calling our original method.
        # Lists of fds are never sent on the wire.
        if self.member.is_fd:
            wire = False
        Type.make_member_of(self, module, complex_type, field_type, field_name, visible, wire, auto, enum, self.member.is_fd)

    def resolve(self, module):
        '''
        Resolve the member type and length expression, refresh size and start
        alignment from the resolved member, and locate the length Field.
        '''
        if self.resolved:
            return
        self.member.resolve(module)
        self.expr.resolve(module, self.parents)

        # resolve() could have changed the size (ComplexType starts with size 0)
        self.size = self.member.size if self.member.fixed_size() else None

        self.required_start_align = self.member.required_start_align

        # Find my length field again.  We need the actual Field object in the expr.
        # This is needed because we might have added it ourself in make_member_of().
        # The break only leaves the inner loop, so a match in a later parent
        # overrides one found in an earlier parent.
        if not self.fixed_size():
            for parent in self.parents:
                for field in parent.fields:
                    if field.field_name == self.expr.lenfield_name and field.wire:
                        self.expr.lenfield = field
                        break

        self.resolved = True

    def fixed_size(self):
        return self.member.fixed_size() and self.expr.fixed_size()

    def unchecked_get_alignment_after(self, start_align, callstack, log):
        '''
        Alignment after the whole list.  For a fixed element count the member
        alignment is applied nmemb times; otherwise the alignment is combined
        over repeated groups of expr.get_multiple() elements until it no
        longer changes.  Returns None on an alignment failure.
        '''
        my_callstack = callstack[:]
        my_callstack.append(self)
        if start_align is None:
            # log may be None; guard like every other log call here
            if log is not None:
                log.fail(start_align, "", self, my_callstack, "start_align is None")
            return None
        if self.expr.fixed_size():
            # fixed number of elements
            num_elements = self.nmemb
            prev_alignment = None
            alignment = start_align
            while num_elements > 0:
                if alignment is None:
                    if log is not None:
                        log.fail(start_align, "", self, my_callstack,
                            ("fixed size list with size %d after %d iterations"
                            + ", at transition from alignment \"%s\"")
                            % (self.nmemb,
                               (self.nmemb - num_elements),
                               str(prev_alignment)))
                    return None
                prev_alignment = alignment
                alignment = self.member.get_alignment_after(prev_alignment, my_callstack, log)
                num_elements -= 1
            if log is not None:
                log.ok(start_align, "", self, my_callstack, alignment)
            return alignment
        else:
            # variable number of elements
            # check whether the number of elements is a multiple
            multiple = self.expr.get_multiple()
            assert multiple > 0

            # iterate until the combined alignment does not change anymore
            alignment = start_align
            while True:
                prev_multiple_alignment = alignment
                # apply "multiple" amount of changes sequentially
                prev_alignment = alignment
                for _ in range(multiple):

                    after_alignment = self.member.get_alignment_after(prev_alignment, my_callstack, log)
                    if after_alignment is None:
                        if log is not None:
                            log.fail(start_align, "", self, my_callstack,
                                ("variable size list "
                                + "at transition from alignment \"%s\"")
                                % (str(prev_alignment)))
                        return None

                    prev_alignment = after_alignment

                # combine with the cumulatively combined alignment
                # (to model the variable number of entries)
                alignment = prev_multiple_alignment.combine_with(after_alignment)

                if alignment == prev_multiple_alignment:
                    # does not change anymore by adding more potential elements
                    # -> finished
                    if log is not None:
                        log.ok(start_align, "", self, my_callstack, alignment)
                    return alignment
427
class ExprType(PrimitiveType):
    '''
    Derived class which represents an exprfield.  Fixed size.

    Public fields added:
    member is the underlying datatype of the field.
    parents is the tuple of parent types passed to the constructor.
    expr is an Expression object containing the value of the field.
    '''
    def __init__(self, elt, member, *parents):
        PrimitiveType.__init__(self, member.name, member.size)
        self.is_expr = True
        self.member = member
        self.parents = parents

        # the first child element of <exprfield> holds the value expression
        value_elt = list(elt)[0]
        self.expr = Expression(value_elt, self)

    def resolve(self, module):
        if not self.resolved:
            self.member.resolve(module)
            self.resolved = True
448
449
class PadType(Type):
    '''
    Derived class which represents a padding field.

    Public fields added:
    align is the alignment an align-pad pads up to (1 for an ordinary pad).
    serialize is True when the element's "serialize" attribute is one of
        true_values; False otherwise, including when no element is given.
    '''
    def __init__(self, elt):
        Type.__init__(self, tcard8.name)
        self.is_pad = True
        self.size = 1
        self.nmemb = 1
        self.align = 1
        # Default so the attribute exists even for pads built without an element.
        self.serialize = False
        if elt is not None:
            # base 0 accepts decimal or hex attribute values
            self.nmemb = int(elt.get('bytes', "1"), 0)
            self.align = int(elt.get('align', "1"), 0)
            self.serialize = elt.get('serialize', "false").lower() in true_values

        # pads don't require any alignment at their start
        self.required_start_align = Alignment(1,0)
        # NOTE(review): max_align_pad keeps Type's default of 1 even for an
        # align-pad, so ComplexType never picks up self.align from its pads.
        # Confirm whether it should be set to self.align.

    def resolve(self, module):
        self.resolved = True

    def fixed_size(self):
        # an align-pad's size depends on where it starts
        return self.align <= 1

    def unchecked_get_alignment_after(self, start_align, callstack, log):
        '''
        Alignment after the pad.  A plain pad advances start_align by its
        byte count.  An align-pad either leaves an already sufficiently
        aligned start_align unchanged, or guarantees Alignment(align, 0).
        '''
        my_callstack = callstack + [self]
        if self.align <= 1:
            # fixed size pad of nmemb bytes
            total_size = self.get_total_size()
            after_align = start_align.align_after_fixed_size(total_size)
            if log is not None:
                if after_align is None:
                    log.fail(start_align, "", self, my_callstack,
                    "align after fixed size pad of size %d failed" % total_size)
                else:
                    log.ok(start_align, "", self, my_callstack, after_align)

            return after_align

        # align-pad
        assert self.align > 1
        assert self.size == 1
        assert self.nmemb == 1
        if (start_align.offset == 0
           and self.align <= start_align.align
           and start_align.align % self.align == 0):
            # the alignment pad is size 0 because the start_align
            # is already sufficiently aligned -> return the start_align
            after_align = start_align
        else:
            # the alignment pad has nonzero size -> return the alignment
            # that is guaranteed by it, independently of the start_align
            after_align = Alignment(self.align, 0)

        if log is not None:
            log.ok(start_align, "", self, my_callstack, after_align)

        return after_align
506
class ComplexType(Type):
    '''
    Derived class which represents a structure.  Base type for all structure types.

    Public fields added:
    fields is an array of Field objects describing the structure fields.
    length_expr is an expression that defines the length of the structure.
    elt is the XML element the type was built from; resolve() walks its children.
    lenfield_parent is the list of types passed as parents to the list,
    exprfield and switch member types created in resolve() ([self] here).

    '''
    def __init__(self, name, elt):
        '''
        Set up an unresolved structure from its XML element.

        The start alignment comes from an explicit <required_start_align>
        child (align defaults to 4, offset to 0).  Without one it is left
        as None so that resolve() computes it.
        '''
        Type.__init__(self, name)
        self.is_container = True
        self.elt = elt
        self.fields = []
        self.nmemb = 1
        self.size = 0
        self.lenfield_parent = [self]
        self.length_expr = None

        # get required_start_alignment
        required_start_align_element = elt.find("required_start_align")
        if required_start_align_element is None:
            # unknown -> mark for autocompute
            self.required_start_align = None
        else:
            # base 0 accepts decimal or hex attribute values
            self.required_start_align = Alignment(
                int(required_start_align_element.get('align', "4"), 0),
                int(required_start_align_element.get('offset', "0"), 0))
            if verbose_align_log:
                print ("Explicit start-align for %s: %s\n" % (self, self.required_start_align))

    def resolve(self, module):
        '''
        Build self.fields from the children of self.elt and resolve each
        field's type.  Then check implicit alignment, compute the size,
        and compute or check required_start_align.  No-op if already resolved.
        '''
        if self.resolved:
            return

        # Resolve all of our field datatypes.
        for child in list(self.elt):
            enum = None
            if child.tag == 'pad':
                # pads get a generated name from the module-wide pad counter
                field_name = 'pad' + str(module.pads)
                fkey = 'CARD8'
                type = PadType(child)
                module.pads = module.pads + 1
                visible = False
            elif child.tag == 'field':
                field_name = child.get('name')
                enum = child.get('enum')
                fkey = child.get('type')
                type = module.get_type(fkey)
                visible = True
            elif child.tag == 'exprfield':
                field_name = child.get('name')
                fkey = child.get('type')
                type = ExprType(child, module.get_type(fkey), *self.lenfield_parent)
                visible = False
            elif child.tag == 'list':
                field_name = child.get('name')
                fkey = child.get('type')
                if fkey == 'fd':
                    # fd lists get a FileDescriptor member; the field type
                    # name is then looked up as INT32
                    ftype = FileDescriptor()
                    fkey = 'INT32'
                else:
                    ftype = module.get_type(fkey)
                type = ListType(child, ftype, *self.lenfield_parent)
                visible = True
            elif child.tag == 'switch':
                field_name = child.get('name')
                # construct the switch type name from the parent type and the field name
                field_type = self.name + (field_name,)
                type = SwitchType(field_type, child, *self.lenfield_parent)
                visible = True
                # switches are added and resolved right here, bypassing the
                # generic code after this if/elif chain.
                # NOTE(review): that also skips the max_align_pad update
                # below, so align-pads inside a switch are not propagated.
                # Confirm this is intended.
                type.make_member_of(module, self, field_type, field_name, visible, True, False)
                type.resolve(module)
                continue
            elif child.tag == 'fd':
                fd_name = child.get('name')
                type = module.get_type('INT32')
                type.make_fd_of(module, self, fd_name)
                continue
            elif child.tag == 'length':
                self.length_expr = Expression(list(child)[0], self)
                continue
            else:
                # Hit this on Reply
                # (every other child tag is silently ignored as well)
                continue

            # Get the full type name for the field
            field_type = module.get_type_name(fkey)
            # Add the field to ourself
            type.make_member_of(module, self, field_type, field_name, visible, True, False, enum)
            # Recursively resolve the type (could be another structure, list)
            type.resolve(module)

            # Track the largest align-pad contained in any field type
            if type.max_align_pad > self.max_align_pad:
                self.max_align_pad = type.max_align_pad

        self.check_implicit_fixed_size_part_aligns();

        self.calc_size() # Figure out how big we are
        self.calc_or_check_required_start_align()

        self.resolved = True

    def calc_size(self):
        '''
        Set self.size to the sum of get_total_size() over all wire fields.
        If a variable-sized wire field is found, set it to None instead.
        '''
        self.size = 0
        for m in self.fields:
            if not m.wire:
                continue
            if m.type.fixed_size():
                self.size = self.size + m.type.get_total_size()
            else:
                self.size = None
                break

    def calc_or_check_required_start_align(self):
        '''
        If required_start_align is None, compute the minimal one.  Otherwise
        check the configured one.  Problems are only printed (ERROR/WARNING
        messages on stdout), never raised.
        '''
        if self.required_start_align is None:
            # no required-start-align configured -> calculate it
            log = AlignmentLog()
            callstack = []
            self.required_start_align = self.calc_minimally_required_start_align(callstack, log)
            if self.required_start_align is None:
                print ("ERROR: could not calc required_start_align of %s\nDetails:\n%s"
                    % (str(self), str(log)))
            else:
                if verbose_align_log:
                    print ("calc_required_start_align: %s has start-align %s"
                        % (str(self), str(self.required_start_align)))
                    print ("Details:\n" + str(log))
                if self.required_start_align.offset != 0:
                    print (("WARNING: %s\n\thas start-align with non-zero offset: %s"
                        + "\n\tsuggest to add explicit definition with:"
                        + "\n\t\t<required_start_align align=\"%d\" offset=\"%d\" />"
                        + "\n\tor to fix the xml so that zero offset is ok\n")
                        % (str(self), self.required_start_align,
                           self.required_start_align.align,
                           self.required_start_align.offset))
        else:
            # required-start-align configured -> check it
            log = AlignmentLog()
            callstack = []
            if not self.is_possible_start_align(self.required_start_align, callstack, log):
                print ("ERROR: required_start_align %s of %s causes problems\nDetails:\n%s"
                    % (str(self.required_start_align), str(self), str(log)))


    def calc_minimally_required_start_align(self, callstack, log):
        '''
        Try Alignment(align, offset) for align in 1, 2, 4, 8 and each offset
        in 0..align-1, in that order.  Return the first candidate accepted by
        is_possible_start_align and append its log to `log`.  If none is
        accepted, append the log of the best failing candidate and return None.
        '''
        # calculate the minimally required start_align that causes no
        # align errors
        best_log = None
        best_failed_align = None
        for align in [1,2,4,8]:
            for offset in range(0,align):
                align_candidate = Alignment(align, offset)
                if verbose_align_log:
                    print ("trying %s for %s" % (str(align_candidate), str(self)))
                my_log = AlignmentLog()
                if self.is_possible_start_align(align_candidate, callstack, my_log):
                    log.append(my_log)
                    if verbose_align_log:
                        print ("found start-align %s for %s" % (str(align_candidate), str(self)))
                    return align_candidate
                else:
                    # "best" failure = most ok log entries.  Because `and`
                    # binds tighter than `or`, the `!= 8` test applies only
                    # to the tie-break clause.
                    my_ok_count = my_log.ok_count()
                    if (best_log is None
                       or my_ok_count > best_log.ok_count()
                       or (my_ok_count == best_log.ok_count()
                          and align_candidate.align > best_failed_align.align)
                          and align_candidate.align != 8):
                        best_log = my_log
                        best_failed_align = align_candidate



        # none of the candidates applies
        # this type has illegal internal aligns for all possible start_aligns
        if verbose_align_log:
            print ("didn't find start-align for %s" % str(self))
        log.append(best_log)
        return None

    def is_possible_start_align(self, align, callstack, log):
        '''
        True if `align` is usable as the start alignment of this type.  It
        must not be None, its align must be a multiple of max_align_pad
        (which rules out smaller ones), and laying out the fields from it
        must not fail.
        '''
        if align is None:
            return False
        if (self.max_align_pad > align.align
           or align.align % self.max_align_pad != 0):
            # our align pad implementation depends on known alignment
            # at the start of our type
            return False

        return self.get_alignment_after(align, callstack, log) is not None

    def fixed_size(self):
        '''
        True if every field type is fixed-size.  Unlike calc_size, this also
        considers non-wire fields.
        '''
        for m in self.fields:
            if not m.type.fixed_size():
                return False
        return True


    # default impls of polymorphic methods which assume sequential layout of fields
    # (like Struct or CaseOrBitcaseType)
    def check_implicit_fixed_size_part_aligns(self):
        '''
        Raise an Exception if a field in the leading fixed-size run of wire
        fields lacks a required_start_align, or appears misaligned.  The
        misalignment check targets places where a C compiler would insert
        implicit padding.
        '''
        # find places where the implementation of the C-binding would
        # create code that makes the compiler add implicit alignment.
        # make these places explicit, so we have
        # consistent behaviour for all bindings
        # NOTE(review): `size` is never advanced inside the loop, so every
        # field is checked as if it started at offset 0.  Only the field's
        # own required offset/align affect mis_align.  Advancing it by each
        # field's total size looks intended, but would newly reject some
        # existing XML.  Confirm before changing.
        size = 0
        for field in self.fields:
            if not field.wire:
                continue
            if not field.type.fixed_size():
                # end of fixed-size part
                break
            required_field_align = field.type.required_start_align
            if required_field_align is None:
                raise Exception(
                    "field \"%s\" in \"%s\" has not required_start_align"
                    % (field.field_name, self.name)
                )
            mis_align = (size + required_field_align.offset) % required_field_align.align
            if mis_align != 0:
                # implicit align pad is required
                padsize = required_field_align.align - mis_align
                raise Exception(
                    "C-compiler would insert implicit alignpad of size %d before field \"%s\" in \"%s\""
                    % (padsize, field.field_name, self.name)
                )

    def unchecked_get_alignment_after(self, start_align, callstack, log):
        '''
        Thread the alignment through all wire fields in order.  Return the
        alignment after the last one, or None if start_align is None or
        any field rejects its incoming alignment.
        '''
        # default impl assumes sequential layout of fields
        # (like Struct or CaseOrBitcaseType)
        my_align = start_align
        if my_align is None:
            return None

        for field in self.fields:
            if not field.wire:
                continue
            my_callstack = callstack[:]
            my_callstack.extend([self, field])

            prev_align = my_align
            my_align = field.type.get_alignment_after(my_align, my_callstack, log)
            if my_align is None:
                if log is not None:
                    log.fail(prev_align, field.field_name, self, my_callstack,
                        "alignment is incompatible with this field")
                return None
            else:
                if log is not None:
                    log.ok(prev_align, field.field_name, self, my_callstack, my_align)

        if log is not None:
            my_callstack = callstack[:]
            my_callstack.append(self)
            log.ok(start_align, "", self, my_callstack, my_align)
        return my_align
764
765
class SwitchType(ComplexType):
    '''
    Derived class representing a <switch>: a container built from <bitcase>
    and <case> children, whose size can only be determined at runtime.

    Public fields added:
    parents is the tuple of enclosing types passed to the constructor.
    bitcases is a list of Field objects, one per <bitcase>/<case> child, whose
    types are BitcaseType or CaseType instances.
    fields (inherited) additionally collects the fields of all bitcases/cases.
    expr is the Expression parsed from the first child element of the switch
    (or from the switch element itself when it has no children).
    '''

    def __init__(self, name, elt, *parents):
        ComplexType.__init__(self, name, elt)
        self.parents = parents
        # FIXME: switch cannot store lenfields, so it should just delegate the parents
        self.lenfield_parent = list(parents) + [self]
        # self.fields contains all possible fields collected from the Bitcase objects,
        # whereas self.bitcases contains the Bitcase objects themselves
        self.bitcases = []

        self.is_switch = True
        elts = list(elt)
        self.expr = Expression(elts[0] if len(elts) else elt, self)

    def resolve(self, module):
        '''
        Build a BitcaseType/CaseType for every <bitcase>/<case> child, register
        it in self.bitcases, resolve it and merge its fields into self.fields.
        Does nothing if already resolved.
        '''
        if self.resolved:
            return

        parents = list(self.parents) + [self]

        # Resolve all of our field datatypes.
        for index, child in enumerate(list(self.elt)):
            if child.tag == 'bitcase' or child.tag == 'case':
                field_name = child.get('name')
                if field_name is None:
                    # unnamed branch: derive the type name from tag and position
                    field_type = self.name + ('%s%d' % ( child.tag, index ),)
                else:
                    field_type = self.name + (field_name,)

                # pass our parents plus ourself as ancestors,
                # as switch does not contain named fields itself
                if child.tag == 'bitcase':
                    case_type = BitcaseType(index, field_type, child, *parents)
                else:
                    case_type = CaseType(index, field_type, child, *parents)

                if field_name is None:
                    case_type.has_name = False
                    # Get the full type name for the field
                    field_type = case_type.name
                visible = True

                # register the branch in self.bitcases
                case_type.make_member_of(module, self, field_type, field_name, visible, True, False)

                # recursively resolve the type (could be another structure, list)
                case_type.resolve(module)
                for new_field in case_type.fields:
                    # Each new field overwrites a _placeholder_byte if one is
                    # still present, otherwise it is appended.  The lookup is
                    # done per field so that no field is dropped after a
                    # placeholder has been replaced.
                    for (idx, field) in enumerate(self.fields):
                        if field == _placeholder_byte:
                            self.fields[idx] = new_field
                            break
                    else:
                        self.fields.append(new_field)

        self.calc_size() # Figure out how big we are
        self.calc_or_check_required_start_align()
        self.resolved = True

    def make_member_of(self, module, complex_type, field_type, field_name, visible, wire, auto, enum=None):
        '''
        Add this switch as a field of complex_type.

        A switch is never fixed-size, so first make sure its length field
        (named by self.expr) exists: if no parent already has a field of that
        name, the length field's type is looked up in module and added to
        complex_type.  Then the switch itself is added via Type.make_member_of.
        '''
        if not self.fixed_size():
            # We need a length field.
            # Ask our Expression object for its name, type, and whether it's on the wire.
            lenfid = self.expr.lenfield_type
            lenfield_name = self.expr.lenfield_name
            lenwire = self.expr.lenwire

            # See if the length field is already in one of the parent structures.
            needlen = not any(field.field_name == lenfield_name
                              for parent in self.parents
                              for field in parent.fields)

            # It isn't, so we need to add it to the structure ourself.
            if needlen:
                lenfield_type_obj = module.get_type(lenfid)
                lenfield_type = module.get_type_name(lenfid)
                lenfield_type_obj.make_member_of(module, complex_type, lenfield_type, lenfield_name, True, lenwire, False, enum)

        # Add ourself to the structure by calling our original method.
        Type.make_member_of(self, module, complex_type, field_type, field_name, visible, wire, auto, enum)

    # size for switch can only be calculated at runtime
    def calc_size(self):
        pass

    # note: switch is _always_ of variable size, so this unconditionally
    # reports False (it does not inspect whether the contained fields are
    # variable-sized themselves)
    def fixed_size(self):
        return False

    def check_implicit_fixed_size_part_aligns(self):
        # this is done for the CaseType or BitCaseType
        return

    def unchecked_get_alignment_after(self, start_align, callstack, log):
        '''
        Return the alignment after this switch when it starts at start_align.

        BitCases are assumed to appear in any combination, and at most one
        Case is assumed to be active (Cases are mutually exclusive).  With
        Cases present, the alignment is computed once per Case (that Case
        selected, all BitCases optional) and the results are combined;
        otherwise it is computed once with only the optional BitCases.

        Returns None (after logging the failure) if any computation fails.
        '''
        # get all Cases (we assume that at least one case is selected if there are cases)
        case_fields = [field for field in self.bitcases if field.type.is_case]

        if not case_fields:
            # there are no case-fields -> check without case-fields
            case_fields = [None]

        my_callstack = callstack[:]
        my_callstack.append(self)

        total_align = None
        first = True
        for case_field in case_fields:
            my2_callstack = my_callstack[:]
            if case_field is not None:
                my2_callstack.append(case_field)

            case_align = self.get_align_for_selected_case_field(
                             case_field, start_align, my2_callstack, log)

            if case_align is None:
                if log is not None:
                    if case_field is None:
                        log.fail(start_align, "", self, my2_callstack,
                            "alignment without cases (only bitcases) failed")
                    else:
                        # my2_callstack already ends with case_field
                        log.fail(start_align, "", self, my2_callstack,
                            "alignment for selected case %s failed"
                            % case_field.field_name)
                return None

            # combine the alignments of all possible case selections
            if first:
                total_align = case_align
                first = False
            else:
                total_align = total_align.combine_with(case_align)

            if log is not None:
                if case_field is None:
                    log.ok(
                        start_align,
                        "without cases (only arbitrary bitcases)",
                        self, my2_callstack, case_align)
                else:
                    log.ok(
                        start_align,
                        "case %s and arbitrary bitcases" % case_field.field_name,
                        self, my2_callstack, case_align)

        if log is not None:
            log.ok(start_align, "", self, my_callstack, total_align)
        return total_align

    # aux function for unchecked_get_alignment_after
    def get_align_for_selected_case_field(self, case_field, start_align, callstack, log):
        '''
        Return the alignment after the switch assuming case_field (a Case
        Field, or None for "no case") is the active case and every BitCase is
        optional.  Non-wire fields and unselected Cases are ignored.
        Returns None as soon as a branch rejects its start alignment.
        '''
        if verbose_align_log:
            print ("get_align_for_selected_case_field: %s, case_field = %s" % (str(self), str(case_field)))
        total_align = start_align
        for field in self.bitcases:
            if not field.wire:
                continue

            my_callstack = callstack[:]
            my_callstack.append(field)

            if field is case_field:
                # assume that this field is active -> no combine_with to emulate optional
                after_field_align = field.type.get_alignment_after(total_align, my_callstack, log)

                if log is not None:
                    if after_field_align is None:
                        log.fail(total_align, field.field_name, field.type, my_callstack,
                            "invalid aligment for this case branch")
                    else:
                        log.ok(total_align, field.field_name, field.type, my_callstack,
                            after_field_align)

                total_align = after_field_align
            elif field.type.is_bitcase:
                after_field_align = field.type.get_alignment_after(total_align, my_callstack, log)
                # we assume that this field is optional, therefore combine
                # alignment after the field with the alignment before the field.
                if after_field_align is None:
                    if log is not None:
                        log.fail(total_align, field.field_name, field.type, my_callstack,
                            "invalid aligment for this bitcase branch")
                    total_align = None
                else:
                    if log is not None:
                        log.ok(total_align, field.field_name, field.type, my_callstack,
                            after_field_align)

                    # combine with the align before the field because
                    # the field is optional
                    total_align = total_align.combine_with(after_field_align)
            else:
                # ignore other fields as they are irrelevant for alignment
                continue

            if total_align is None:
                break

        return total_align
989
990
class Struct(ComplexType):
    '''
    Struct data type: a ComplexType whose fields follow one another, relying
    entirely on the inherited (sequential-layout) ComplexType behavior.
    '''
    # output handler registered by the driver script for structs
    out = __main__.output['struct']
996
997
class Union(ComplexType):
    '''
    Derived class representing a union data type.

    Every member starts at the beginning of the union.  The size is the
    largest total size among the wire members, or None if any wire member is
    variable-sized.
    '''
    def __init__(self, name, elt):
        ComplexType.__init__(self, name, elt)
        self.is_union = True

    out = __main__.output['union']


    def calc_size(self):
        '''
        Set self.size to the largest fixed wire-member size (0 when there are
        no wire members), or to None once a variable-sized wire member is seen.
        '''
        self.size = 0
        for m in self.fields:
            if not m.wire:
                continue
            if m.type.fixed_size():
                self.size = max(self.size, m.type.get_total_size())
            else:
                self.size = None
                break


    def check_implicit_fixed_size_part_aligns(self):
        # a union does not have implicit aligns because all fields start
        # at the start of the union
        return


    def unchecked_get_alignment_after(self, start_align, callstack, log):
        '''
        Return the alignment after this union when it starts at start_align.

        Fixed-size union: every member must accept start_align; the result is
        start_align advanced by the union's total size.
        Variable-sized union: each member's after-alignment is computed from
        start_align and all of them are combined; any failing member makes the
        result None.  A union without members leaves the alignment unchanged.

        Returns None (after logging) when the alignment is not acceptable.
        '''
        # reduce() is not a builtin on Python 3 (which this module supports),
        # so it must be imported explicitly.
        from functools import reduce

        my_callstack = callstack[:]
        my_callstack.append(self)

        after_align = None
        if self.fixed_size():

            # check proper alignment for all members; a list (not a generator)
            # so every member is checked and can log, even after a failure
            start_align_ok = all(
                [field.type.is_acceptable_start_align(start_align, my_callstack + [field], log)
                for field in self.fields])

            if start_align_ok:
                #compute the after align from the start_align
                after_align = start_align.align_after_fixed_size(self.get_total_size())
            else:
                after_align = None

            if log is not None and after_align is not None:
                log.ok(start_align, "fixed sized union", self, my_callstack, after_align)

        else:
            if start_align is None:
                if log is not None:
                    log.fail(start_align, "", self, my_callstack,
                        "missing start_align for union")
                return None

            member_aligns = [
                field.type.get_alignment_after(start_align, my_callstack + [field], log)
                for field in self.fields]

            if member_aligns:
                after_align = reduce(
                    lambda x, y: None if x is None or y is None else x.combine_with(y),
                    member_aligns)
            else:
                # no members: nothing changes the alignment (and reduce() would
                # raise TypeError on an empty sequence)
                after_align = start_align

            if log is not None and after_align is not None:
                log.ok(start_align, "var sized union", self, my_callstack, after_align)


        if after_align is None and log is not None:
            log.fail(start_align, "", self, my_callstack, "start_align is not ok for all members")

        return after_align
1068
class CaseOrBitcaseType(ComplexType):
    '''
    Common base of BitcaseType and CaseType: one branch of a <switch>.

    Constructor arguments: index is the branch's position among the switch's
    children; name is the branch's type name; elt is the branch element;
    *parent are the enclosing types (the switch itself included).

    Public fields added:
    expr is a list of Expression objects built from the <enumref> children;
    those children are removed from elt before ComplexType sees it.
    has_name is True unless the owning switch clears it for unnamed branches.
    parents / lenfield_parent list the enclosing types (the latter plus self).
    '''
    def __init__(self, index, name, elt, *parent):
        enumrefs = [sub for sub in list(elt) if sub.tag == 'enumref']
        self.expr = []
        for ref in enumrefs:
            self.expr.append(Expression(ref, self))
            elt.remove(ref)
        ComplexType.__init__(self, name, elt)
        self.has_name = True
        # NOTE(review): the index argument is not stored; 1 is used instead.
        # Kept unchanged to preserve behavior -- confirm whether that is intended.
        self.index = 1
        self.parents = list(parent)
        self.lenfield_parent = self.parents + [self]
        self.is_case_or_bitcase = True

    def make_member_of(self, module, switch_type, field_type, field_name, visible, wire, auto, enum=None):
        '''
        Register this branch with the SwitchType that owns it.

        module is the global module object.
        switch_type is the owning SwitchType; the new Field is stored in its
        bitcases list, replacing a _placeholder_byte entry if one is present.
        See Field for the meaning of the other parameters.
        '''
        new_field = Field(self, field_type, field_name, visible, wire, auto, enum)

        slots = switch_type.bitcases
        for pos, existing in enumerate(slots):
            if existing == _placeholder_byte:
                slots[pos] = new_field
                break
        else:
            slots.append(new_field)

    def resolve(self, module):
        '''
        Resolve the branch's enumref expressions against its ancestors, then
        its fields (ComplexType.resolve), then compute its start alignment.
        '''
        if not self.resolved:
            for expression in self.expr:
                expression.resolve(module, self.parents + [self])

            ComplexType.resolve(self, module)

            self.calc_or_check_required_start_align()
1117
1118
class BitcaseType(CaseOrBitcaseType):
    '''
    A <bitcase> branch of a switch; any combination of bitcases may be
    present, so alignment checks treat each one as optional.
    '''
    def __init__(self, index, name, elt, *parent):
        super(BitcaseType, self).__init__(index, name, elt, *parent)
        self.is_bitcase = True
1126
class CaseType(CaseOrBitcaseType):
    '''
    A <case> branch of a switch; alignment checks assume at most one case
    of a switch is active at a time.
    '''
    def __init__(self, index, name, elt, *parent):
        super(CaseType, self).__init__(index, name, elt, *parent)
        self.is_case = True
1134
1135
class Reply(ComplexType):
    '''
    Derived class representing a reply.  Only found as a field of Request.

    Public fields added:
    doc is the Doc built from the last <doc> child element, or None.
    '''
    def __init__(self, name, elt):
        ComplexType.__init__(self, name, elt)
        self.is_reply = True
        self.doc = None
        if self.required_start_align is None:
            self.required_start_align = Alignment(4,0)

        doc_elts = [child for child in list(elt) if child.tag == 'doc']
        for doc_elt in doc_elts:
            self.doc = Doc(name, doc_elt)

    def resolve(self, module):
        '''
        Reset the module's pad counter, append the automatic reply header
        fields (response_type, a pad byte, sequence, length) and resolve via
        ComplexType.resolve.  Does nothing once resolved.
        '''
        if not self.resolved:
            # Reset pads count
            module.pads = 0
            self.fields.extend([
                Field(tcard8, tcard8.name, 'response_type', False, True, True),
                _placeholder_byte,
                Field(tcard16, tcard16.name, 'sequence', False, True, True),
                Field(tcard32, tcard32.name, 'length', False, True, True),
            ])
            ComplexType.resolve(self, module)
1162
1163
class Request(ComplexType):
    '''
    Derived class representing a request.

    Public fields added:
    reply is the Reply datatype, or None for void requests.
    opcode is the request number, as read from the 'opcode' attribute.
    doc is the Doc built from the last <doc> child element, or None.
    '''
    out = __main__.output['request']

    def __init__(self, name, elt):
        ComplexType.__init__(self, name, elt)
        self.reply = None
        self.doc = None
        self.opcode = elt.get('opcode')
        if self.required_start_align is None:
            self.required_start_align = Alignment(4,0)

        for child in list(elt):
            tag = child.tag
            if tag == 'reply':
                self.reply = Reply(name, child)
            elif tag == 'doc':
                self.doc = Doc(name, child)

    def resolve(self, module):
        '''
        Append the automatic request header, resolve the request, then its
        reply (if any).  Does nothing once resolved.

        Header: major_opcode, then minor_opcode for extension requests or a
        pad byte for core requests, then length.
        '''
        if self.resolved:
            return
        is_ext = module.namespace.is_ext
        header = [Field(tcard8, tcard8.name, 'major_opcode', False, True, True)]
        if is_ext:
            header.append(Field(tcard8, tcard8.name, 'minor_opcode', False, True, True))
        else:
            header.append(_placeholder_byte)
        header.append(Field(tcard16, tcard16.name, 'length', False, True, True))
        self.fields.extend(header)
        ComplexType.resolve(self, module)

        if self.reply:
            self.reply.resolve(module)
1205
1206
class EventStructAllowedRule:
    '''
    One <allowed> rule of an <eventstruct>: selects the events of a single
    extension whose opcodes lie in the inclusive range
    [opcode-min, opcode-max]; xge="true" selects generic (XGE) events.
    '''

    def __init__(self, parent, elt):
        # parent is accepted for interface compatibility but is not used here
        self.elt = elt
        self.extension = elt.get('extension')
        self.ge_events = elt.get('xge') == "true"
        self.min_opcode = int( elt.get('opcode-min') )
        self.max_opcode = int( elt.get('opcode-max') )

    def resolve(self, parent, module):
        '''
        Look up every selected event in the extension's namespace and add it
        to parent (the EventStruct) via parent.add_event.

        Raises Exception if the extension or any event in the range is missing.
        '''
        # get the namespace of the specified extension
        extension_namespace = module.get_namespace( self.extension )
        if extension_namespace is None:
            raise Exception( "EventStructAllowedRule.resolve: cannot find extension \"" + self.extension + "\"" )

        # find and add the selected events; opcode-max is inclusive
        for opcode in range(self.min_opcode, self.max_opcode + 1):
            name_and_event = extension_namespace.get_event_by_opcode( opcode, self.ge_events )
            if name_and_event is None:
                # could not find event -> error handling
                if self.ge_events:
                    raise Exception("EventStructAllowedRule.resolve: cannot find xge-event with opcode " + str(opcode) + " in extension " + self.extension )
                else:
                    raise Exception("EventStructAllowedRule.resolve: cannot find oldstyle-event with opcode " + str(opcode) + " in extension " + self.extension )

            ( name, event ) = name_and_event
            # add event to EventStruct
            parent.add_event( module, self.extension, opcode, name, event )
1237
1238
class EventStruct(Union):
    '''
    Derived class representing an event-use-as-struct data type: a union of
    the events selected by its <allowed> rules.

    Public fields added:
    events is a list of (extension, opcode, name, event_type) tuples, filled
    by add_event while the rules are resolved.
    allowedRules holds one EventStructAllowedRule per <allowed> child.
    contains_ge_events is True if any of those rules selects xge events.
    '''
    out = __main__.output['eventstruct']

    def __init__(self, name, elt):
        Union.__init__(self, name, elt)
        self.is_eventstruct = True
        self.events = []
        self.allowedRules = [EventStructAllowedRule(self, item)
                             for item in list(elt) if item.tag == 'allowed']
        self.contains_ge_events = any(rule.ge_events for rule in self.allowedRules)

    def resolve(self, module):
        '''Apply every allowed rule (which adds the events), then resolve as a Union.'''
        if not self.resolved:
            for rule in self.allowedRules:
                rule.resolve(self, module)
            Union.resolve(self, module)
            self.resolved = True

    # add event. called by resolve
    def add_event(self, module, extension, opcode, name, event_type ):
        '''Record the event, add it as a member field named name[-1], and resolve it.'''
        self.events.append( (extension, opcode, name, event_type) )
        # Add the field to ourself
        event_type.make_member_of(module, self, name, name[-1], True, True, False)
        # Recursively resolve the event (could be another structure, list)
        event_type.resolve(module)

    def fixed_size(self):
        '''True only if every contained event is fixed-size (True when there are none).'''
        # a list, so fixed_size() is queried on every event as before
        per_event = [event.fixed_size() for (_ext, _opcode, _name, event) in self.events]
        return all(per_event)
1281
1282
class Event(ComplexType):
    '''
    Derived class representing an event data type.

    Public fields added:
    opcodes is a dictionary of name -> opcode number, for eventcopies.
    has_seq is False when the 'no-sequence-number' attribute holds a true value.
    is_ge_event is True when the 'xge' attribute holds a true value.
    doc is the Doc built from the last <doc> child element, or None.
    '''
    def __init__(self, name, elt):
        ComplexType.__init__(self, name, elt)

        if self.required_start_align is None:
            self.required_start_align = Alignment(4,0)

        self.opcodes = {}

        # Boolean XML attributes: only the spellings in true_values
        # (case-insensitive) count as set.  A plain bool() of the attribute
        # string would treat e.g. xge="false" as true.
        self.has_seq = (elt.get('no-sequence-number') or '').lower() not in true_values

        self.is_ge_event = (elt.get('xge') or '').lower() in true_values

        self.is_event = True

        self.doc = None
        for item in list(elt):
            if item.tag == 'doc':
                self.doc = Doc(name, item)

    def add_opcode(self, opcode, name, main):
        '''Record opcode under name; if main, name also becomes this event's name.'''
        self.opcodes[name] = opcode
        if main:
            self.name = name

    def get_name_for_opcode(self, opcode):
        '''Return the name registered for the integer opcode, or None.'''
        for name, my_opcode in self.opcodes.items():
            if int(my_opcode) == opcode:
                return name
        return None

    def resolve(self, module):
        '''
        Append the automatic event header fields and resolve via
        ComplexType.resolve.  Does nothing once resolved.
        '''
        def add_event_header():
            # core-style header: response_type, then (unless the event has no
            # sequence number) a pad byte and the sequence number
            self.fields.append(Field(tcard8, tcard8.name, 'response_type', False, True, True))
            if self.has_seq:
                self.fields.append(_placeholder_byte)
                self.fields.append(Field(tcard16, tcard16.name, 'sequence', False, True, True))

        def add_ge_event_header():
            # generic-event (xge) header
            self.fields.append(Field(tcard8,  tcard8.name,  'response_type', False, True, True))
            self.fields.append(Field(tcard8,  tcard8.name,  'extension', False, True, True))
            self.fields.append(Field(tcard16, tcard16.name, 'sequence', False, True, True))
            self.fields.append(Field(tcard32, tcard32.name, 'length', False, True, True))
            self.fields.append(Field(tcard16, tcard16.name, 'event_type', False, True, True))

        if self.resolved:
            return

        # Add the automatic protocol fields
        if self.is_ge_event:
            add_ge_event_header()
        else:
            add_event_header()

        ComplexType.resolve(self, module)

    out = __main__.output['event']
1347
1348
class Error(ComplexType):
    '''
    Derived class representing an error data type.

    Public fields added:
    opcodes maps name -> opcode number, for errorcopies.
    '''
    def __init__(self, name, elt):
        ComplexType.__init__(self, name, elt)
        self.opcodes = {}
        if self.required_start_align is None:
            self.required_start_align = Alignment(4,0)

        # All errors share one body layout, but for historic reasons their XML
        # may omit parts of it: pad the element with these default fields, in
        # order, until it has at least three children.
        defaults = (("CARD32", "bad_value"),
                    ("CARD16", "minor_opcode"),
                    ("CARD8", "major_opcode"))
        for wanted, (ftype, fname) in enumerate(defaults, 1):
            if len(self.elt) < wanted:
                SubElement(self.elt, "field", type=ftype, name=fname)

    def add_opcode(self, opcode, name, main):
        '''Record opcode under name; if main, name also becomes this error's name.'''
        self.opcodes[name] = opcode
        if main:
            self.name = name

    def resolve(self, module):
        '''
        Append the automatic error header (response_type, error_code,
        sequence) and resolve via ComplexType.resolve; no-op once resolved.
        '''
        if not self.resolved:
            self.fields.extend([
                Field(tcard8, tcard8.name, 'response_type', False, True, True),
                Field(tcard8, tcard8.name, 'error_code', False, True, True),
                Field(tcard16, tcard16.name, 'sequence', False, True, True),
            ])
            ComplexType.resolve(self, module)

    out = __main__.output['error']
1387
1388
class Doc(object):
    '''
    Parsed contents of a <doc> element.

    Public fields:
    name is the name passed in by the owner of the documentation.
    brief and description hold the stripped text of <brief>/<description>
    (brief defaults to 'BRIEF DESCRIPTION MISSING', description to None).
    fields maps a <field>'s name to its stripped text.
    errors maps an <error>'s type to its stripped text.
    see maps a <see>'s name to its type attribute.
    example holds the stripped text of <example>, or None.
    When a tag repeats, the last occurrence wins.
    '''
    def __init__(self, name, elt):
        self.name = name
        self.description = None
        self.brief = 'BRIEF DESCRIPTION MISSING'
        self.fields = {}
        self.errors = {}
        self.see = {}
        self.example = None

        for child in elt:
            tag = child.tag
            body = (child.text or '').strip()
            if tag == 'description':
                self.description = body
            elif tag == 'brief':
                self.brief = body
            elif tag == 'field':
                self.fields[child.get('name')] = body
            elif tag == 'error':
                self.errors[child.get('type')] = body
            elif tag == 'see':
                self.see[child.get('name')] = child.get('type')
            elif tag == 'example':
                self.example = body
1416
1417
1418
# Shared pad field named 'pad0' (its type name is taken from tcard8, so it is
# presumably a single byte). Reply.resolve, Request.resolve (core requests)
# and Event.resolve (events with a sequence number) append this very object as
# a header filler. SwitchType.resolve and CaseOrBitcaseType.make_member_of
# look for it with == and overwrite it in place when adding fields/bitcases.
# NOTE(review): PadType and tcard8 are presumably defined earlier in this file.
_placeholder_byte = Field(PadType(None), tcard8.name, 'pad0', False, True, False)
1420