Source file src/compress/bzip2/huffman.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package bzip2
     6  
     7  import (
     8  	"cmp"
     9  	"slices"
    10  )
    11  
// A huffmanTree is a binary tree which is navigated, bit-by-bit, to reach a
// symbol. Leaves are not stored as nodes of their own: a leaf's symbol is
// kept in its parent's leftValue or rightValue field (see huffmanNode).
type huffmanTree struct {
	// nodes contains all the non-leaf nodes in the tree. nodes[0] is the
	// root of the tree and nextNode contains the index of the next element
	// of nodes to use when the tree is being constructed.
	nodes    []huffmanNode
	nextNode int
}
    21  
    22  // A huffmanNode is a node in the tree. left and right contain indexes into the
    23  // nodes slice of the tree. If left or right is invalidNodeValue then the child
    24  // is a left node and its value is in leftValue/rightValue.
    25  //
    26  // The symbols are uint16s because bzip2 encodes not only MTF indexes in the
    27  // tree, but also two magic values for run-length encoding and an EOF symbol.
    28  // Thus there are more than 256 possible symbols.
    29  type huffmanNode struct {
    30  	left, right           uint16
    31  	leftValue, rightValue uint16
    32  }
    33  
    34  // invalidNodeValue is an invalid index which marks a leaf node in the tree.
    35  const invalidNodeValue = 0xffff
    36  
// Decode reads bits from the given bitReader and navigates the tree until a
// symbol is found, which it returns.
//
// Navigation starts at t.nodes[0], the root, and consumes one bit per node:
// a 1 bit selects node.left and a 0 bit selects node.right. When the selected
// child index is invalidNodeValue, that child is a leaf and the result is
// node.leftValue or node.rightValue respectively.
//
// buildHuffmanNode puts codes whose bit is 0 on the left, so the bits
// consumed here are the complement of the codes newHuffmanTree assigns
// (it numbers codes starting from the longest one).
// NOTE(review): presumably this inversion is what turns those codes into
// the canonical codes bzip2 writes — confirm against the format description.
//
// t must contain at least one node: t.nodes[0] is indexed unconditionally.
func (t *huffmanTree) Decode(br *bitReader) (v uint16) {
	nodeIndex := uint16(0) // node 0 is the root of the tree.

	for {
		node := &t.nodes[nodeIndex]

		var bit uint16
		if br.bits > 0 {
			// Get next bit - fast path.
			// Assuming br.bits counts the valid low-order bits buffered in
			// br.n, this takes the most significant buffered bit. The &63
			// keeps the shift count below 64; presumably it also lets the
			// compiler omit an oversized-shift check — NOTE(review): confirm.
			br.bits--
			bit = uint16(br.n>>(br.bits&63)) & 1
		} else {
			// Get next bit - slow path.
			// Use ReadBits to retrieve a single bit
			// from the underlying io.ByteReader.
			bit = uint16(br.ReadBits(1))
		}

		// Trick a compiler into generating conditional move instead of branch,
		// by making both loads unconditional.
		// Keep this shape when editing: the loads and the selection below are
		// arranged for code generation rather than readability.
		l, r := node.left, node.right

		if bit == 1 {
			nodeIndex = l
		} else {
			nodeIndex = r
		}

		if nodeIndex == invalidNodeValue {
			// We found a leaf. Use the value of bit to decide
			// whether it is a left or a right value.
			l, r := node.leftValue, node.rightValue
			if bit == 1 {
				v = l
			} else {
				v = r
			}
			return
		}
	}
}
    80  
    81  // newHuffmanTree builds a Huffman tree from a slice containing the code
    82  // lengths of each symbol. The maximum code length is 32 bits.
    83  func newHuffmanTree(lengths []uint8) (huffmanTree, error) {
    84  	// There are many possible trees that assign the same code length to
    85  	// each symbol (consider reflecting a tree down the middle, for
    86  	// example). Since the code length assignments determine the
    87  	// efficiency of the tree, each of these trees is equally good. In
    88  	// order to minimize the amount of information needed to build a tree
    89  	// bzip2 uses a canonical tree so that it can be reconstructed given
    90  	// only the code length assignments.
    91  
    92  	if len(lengths) < 2 {
    93  		panic("newHuffmanTree: too few symbols")
    94  	}
    95  
    96  	var t huffmanTree
    97  
    98  	// First we sort the code length assignments by ascending code length,
    99  	// using the symbol value to break ties.
   100  	pairs := make([]huffmanSymbolLengthPair, len(lengths))
   101  	for i, length := range lengths {
   102  		pairs[i].value = uint16(i)
   103  		pairs[i].length = length
   104  	}
   105  
   106  	slices.SortFunc(pairs, func(a, b huffmanSymbolLengthPair) int {
   107  		if c := cmp.Compare(a.length, b.length); c != 0 {
   108  			return c
   109  		}
   110  		return cmp.Compare(a.value, b.value)
   111  	})
   112  
   113  	// Now we assign codes to the symbols, starting with the longest code.
   114  	// We keep the codes packed into a uint32, at the most-significant end.
   115  	// So branches are taken from the MSB downwards. This makes it easy to
   116  	// sort them later.
   117  	code := uint32(0)
   118  	length := uint8(32)
   119  
   120  	codes := make([]huffmanCode, len(lengths))
   121  	for i := len(pairs) - 1; i >= 0; i-- {
   122  		if length > pairs[i].length {
   123  			length = pairs[i].length
   124  		}
   125  		codes[i].code = code
   126  		codes[i].codeLen = length
   127  		codes[i].value = pairs[i].value
   128  		// We need to 'increment' the code, which means treating |code|
   129  		// like a |length| bit number.
   130  		code += 1 << (32 - length)
   131  	}
   132  
   133  	// Now we can sort by the code so that the left half of each branch are
   134  	// grouped together, recursively.
   135  	slices.SortFunc(codes, func(a, b huffmanCode) int {
   136  		return cmp.Compare(a.code, b.code)
   137  	})
   138  
   139  	t.nodes = make([]huffmanNode, len(codes))
   140  	_, err := buildHuffmanNode(&t, codes, 0)
   141  	return t, err
   142  }
   143  
// huffmanSymbolLengthPair contains a symbol and its code length.
// newHuffmanTree sorts a slice of these to decide the order in which
// canonical codes are assigned.
type huffmanSymbolLengthPair struct {
	value  uint16 // the symbol
	length uint8  // the symbol's code length, in bits
}
   149  
// huffmanCode contains a symbol, its code and code length.
type huffmanCode struct {
	code    uint32 // the code, left-aligned: its first bit is the most significant bit
	codeLen uint8  // number of significant bits in code
	value   uint16 // the symbol this code stands for
}
   156  
   157  // buildHuffmanNode takes a slice of sorted huffmanCodes and builds a node in
   158  // the Huffman tree at the given level. It returns the index of the newly
   159  // constructed node.
   160  func buildHuffmanNode(t *huffmanTree, codes []huffmanCode, level uint32) (nodeIndex uint16, err error) {
   161  	test := uint32(1) << (31 - level)
   162  
   163  	// We have to search the list of codes to find the divide between the left and right sides.
   164  	firstRightIndex := len(codes)
   165  	for i, code := range codes {
   166  		if code.code&test != 0 {
   167  			firstRightIndex = i
   168  			break
   169  		}
   170  	}
   171  
   172  	left := codes[:firstRightIndex]
   173  	right := codes[firstRightIndex:]
   174  
   175  	if len(left) == 0 || len(right) == 0 {
   176  		// There is a superfluous level in the Huffman tree indicating
   177  		// a bug in the encoder. However, this bug has been observed in
   178  		// the wild so we handle it.
   179  
   180  		// If this function was called recursively then we know that
   181  		// len(codes) >= 2 because, otherwise, we would have hit the
   182  		// "leaf node" case, below, and not recurred.
   183  		//
   184  		// However, for the initial call it's possible that len(codes)
   185  		// is zero or one. Both cases are invalid because a zero length
   186  		// tree cannot encode anything and a length-1 tree can only
   187  		// encode EOF and so is superfluous. We reject both.
   188  		if len(codes) < 2 {
   189  			return 0, StructuralError("empty Huffman tree")
   190  		}
   191  
   192  		// In this case the recursion doesn't always reduce the length
   193  		// of codes so we need to ensure termination via another
   194  		// mechanism.
   195  		if level == 31 {
   196  			// Since len(codes) >= 2 the only way that the values
   197  			// can match at all 32 bits is if they are equal, which
   198  			// is invalid. This ensures that we never enter
   199  			// infinite recursion.
   200  			return 0, StructuralError("equal symbols in Huffman tree")
   201  		}
   202  
   203  		if len(left) == 0 {
   204  			return buildHuffmanNode(t, right, level+1)
   205  		}
   206  		return buildHuffmanNode(t, left, level+1)
   207  	}
   208  
   209  	nodeIndex = uint16(t.nextNode)
   210  	node := &t.nodes[t.nextNode]
   211  	t.nextNode++
   212  
   213  	if len(left) == 1 {
   214  		// leaf node
   215  		node.left = invalidNodeValue
   216  		node.leftValue = left[0].value
   217  	} else {
   218  		node.left, err = buildHuffmanNode(t, left, level+1)
   219  	}
   220  
   221  	if err != nil {
   222  		return
   223  	}
   224  
   225  	if len(right) == 1 {
   226  		// leaf node
   227  		node.right = invalidNodeValue
   228  		node.rightValue = right[0].value
   229  	} else {
   230  		node.right, err = buildHuffmanNode(t, right, level+1)
   231  	}
   232  
   233  	return
   234  }
   235  

View as plain text