aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/oschwald/maxminddb-golang/traverse.go
blob: f9b443c0dffc9a3d03a52dc5db68d17cf099e923 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
package maxminddb

import "net"

// netNode is one pending entry of the search-tree walk performed by Networks:
// a pointer that still has to be visited, together with the address prefix
// that leads to it.
type netNode struct {
	// ip holds the address bits chosen on the way down the tree; bits from
	// position bit onward are still zero.
	ip net.IP
	// bit is the number of leading address bits fixed so far (the prefix
	// length); pointer is the tree node or data pointer to visit next.
	bit, pointer uint
}

// Networks iterates over the networks stored in a Reader's search tree.
// Obtain one from Reader.Networks, advance it with Next, read the current
// network with Network, and consult Err once Next returns false.
type Networks struct {
	reader *Reader

	nodes    []netNode // stack of tree nodes still to visit; the top is the last element
	lastNode netNode   // node most recently reported by Next

	err error // error that stopped iteration, if any
}

// Networks returns an iterator over every network in the database, starting
// from the root of the search tree.
//
// Note that a MaxMind DB may map IPv4 networks into several locations in an
// IPv6 database. The iterator does not de-duplicate these aliases: each
// location is reported separately.
func (r *Reader) Networks() *Networks {
	// The walk starts at the all-zero address with no bits fixed. Its width
	// follows the database's IP version: 16 bytes for IPv6, otherwise 4.
	var root netNode
	if r.Metadata.IPVersion == 6 {
		root.ip = make(net.IP, net.IPv6len)
	} else {
		root.ip = make(net.IP, net.IPv4len)
	}

	return &Networks{reader: r, nodes: []netNode{root}}
}

// Next advances the iterator to the next network that has data associated
// with it, making that network available to the Network method.
//
// The search tree is walked depth first using n.nodes as an explicit stack.
// At every internal tree node the child for address bit 1 is pushed for
// later and the walk continues into the child for bit 0, so networks are
// reported in ascending address order. A pointer equal to the metadata
// NodeCount is skipped (no network is reported for it); a pointer greater
// than NodeCount ends the branch and is reported as a network.
//
// Next returns true if another network is ready and false when the tree is
// exhausted or an error occurred; call Err to tell the two apart. Once an
// error has been recorded, every later call returns false without touching
// the tree.
func (n *Networks) Next() bool {
	// The error is sticky. After a failure the stack can still hold
	// unvisited nodes, and resuming from them would report networks from a
	// tree already known to be broken while Err says otherwise.
	if n.err != nil {
		return false
	}

	nodeCount := n.reader.Metadata.NodeCount
	for len(n.nodes) > 0 {
		node := n.nodes[len(n.nodes)-1]
		n.nodes = n.nodes[:len(n.nodes)-1]

		for node.pointer != nodeCount {
			if node.pointer > nodeCount {
				// Past the tree nodes: this branch ends in a record.
				n.lastNode = node
				return true
			}

			// node.pointer < nodeCount: an internal tree node. Each level
			// fixes one more address bit, so running out of address bits
			// means the tree is deeper than the address is long.
			if int(node.bit>>3) >= len(node.ip) {
				n.err = newInvalidDatabaseError(
					"invalid search tree at %v/%v", node.ip, node.bit)
				return false
			}

			// The bit-1 child gets its own copy of the prefix with the current
			// bit set; the bit-0 child keeps using node.ip unchanged.
			ipRight := make(net.IP, len(node.ip))
			copy(ipRight, node.ip)
			ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8))

			rightPointer, err := n.reader.readNode(node.pointer, 1)
			if err != nil {
				n.err = err
				return false
			}

			node.bit++
			n.nodes = append(n.nodes, netNode{
				ip:      ipRight,
				bit:     node.bit,
				pointer: rightPointer,
			})

			node.pointer, err = n.reader.readNode(node.pointer, 0)
			if err != nil {
				n.err = err
				return false
			}
		}
	}

	return false
}

// Network returns the network most recently reported by Next and decodes the
// record associated with it into result, which should be a pointer to the
// value to decode into.
//
// If iteration has stopped with an error, that error is returned instead, so
// a stale network is never presented as current. Errors from decoding the
// record are returned unchanged.
func (n *Networks) Network(result interface{}) (*net.IPNet, error) {
	if n.err != nil {
		return nil, n.err
	}
	if err := n.reader.retrieveData(n.lastNode.pointer, result); err != nil {
		return nil, err
	}

	ip := n.lastNode.ip
	return &net.IPNet{
		IP: ip,
		// The prefix length is the tree depth at which the record was found;
		// the mask width matches the address length in bits.
		Mask: net.CIDRMask(int(n.lastNode.bit), len(ip)*8),
	}, nil
}

// Err reports the error recorded while iterating with Next, or nil if none
// has occurred.
func (n *Networks) Err() error {
	if err := n.err; err != nil {
		return err
	}
	return nil
}