about summary refs log blame commit diff stats
path: root/p2p/network.go
blob: 820cef1a91835e9f3d98e62c1be628bbf9cfee6c (plain) (tree)



































































































































































































                                                                                                                                            
package p2p

import (
    "fmt"
    "math/rand"
    "net"
    "strconv"
    "time"
)

// Timing and retry parameters for dialing and NAT port mapping.
const (
    // DialerTimeout is the connect timeout, in seconds, used by TCPNetwork.Dialer.
    DialerTimeout = 180

    // KeepAlivePeriod is a TCP keep-alive period in minutes. In this file it
    // is only referenced by the commented-out KeepAlive setting in
    // TCPNetwork.Dialer.
    KeepAlivePeriod = 60

    // portMappingUpdateInterval is the interval, in seconds (15 minutes), of
    // the timer that re-adds UPnP port mappings in updatePortMappings.
    portMappingUpdateInterval = 900

    // upnpDiscoverAttempts is the attempt count passed to upnpDiscover by Start.
    upnpDiscoverAttempts = 3
)

// Dialer is not an interface in net, so we define one.
// *net.Dialer conforms to this; the assertion below has the compiler verify it.
type Dialer interface {
    // Dial connects to address on the named network.
    Dial(network, address string) (net.Conn, error)
}

// Compile-time check that *net.Dialer satisfies Dialer.
var _ Dialer = (*net.Dialer)(nil)

// Network is the set of transport operations TCPNetwork provides: setup,
// listeners, dialers, and construction of net.Addr values from host/port data.
type Network interface {
    // Start performs any setup the network needs (for TCPNetwork, NAT discovery).
    Start() error
    // Listener returns a listener bound to the given local address.
    Listener(local net.Addr) (net.Listener, error)
    // Dialer returns a Dialer that uses the given local address.
    Dialer(local net.Addr) (Dialer, error)
    // NewAddr builds an address from a host and a port number.
    NewAddr(host string, port int) (net.Addr, error)
    // ParseAddr parses an address string into a net.Addr.
    ParseAddr(address string) (net.Addr, error)
}

// NAT is a NAT traversal device (for example the value returned by
// upnpDiscover) that can report its external IP and manage port mappings.
type NAT interface {
    // GetExternalAddress reports the device's external IP address.
    GetExternalAddress() (net.IP, error)
    // AddPortMapping forwards externalPort on the device to internalPort for
    // the given protocol, labelled with description and requested with the
    // given timeout. It returns the mapped external port.
    AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (int, error)
    // DeletePortMapping removes a mapping for the given protocol and ports.
    DeletePortMapping(protocol string, externalPort, internalPort int) error
}

// TCPNetwork implements the Network operations (Start, Listener, Dialer,
// NewAddr, ParseAddr) over TCP, with optional UPnP port mapping.
// Construct it with NewTCPNetwork, which initializes the ports channel.
type TCPNetwork struct {
    // nat is the NAT device found by Start; it stays nil unless UPnP
    // discovery succeeded.
    nat     NAT
    // natType is the NAT traversal mode given to NewTCPNetwork.
    natType NATType
    // quit is created lazily by Listener when it starts the
    // updatePortMappings goroutine; Stop sends a reply channel on it.
    quit    chan chan bool
    // ports carries listening port numbers (as strings) from Listener to the
    // updatePortMappings goroutine; NewTCPNetwork makes it unbuffered.
    ports   chan string
}

// NATType selects how TCPNetwork makes its listening ports reachable from
// outside a NAT; it is passed to NewTCPNetwork and switched on in Start.
type NATType int

// NAT traversal modes. These are untyped integer constants, so they can be
// used directly wherever a NATType is expected.
const (
    NONE = iota // no NAT traversal
    UPNP        // UPnP discovery and port mapping
    PMP         // NAT-PMP; Start reports it as not implemented
)

// portMappingTimeout is the timeout argument passed to NAT.AddPortMapping:
// 1200 seconds, i.e. 20 minutes.
const portMappingTimeout = 1200

// NewTCPNetwork returns a TCPNetwork configured for the given NAT traversal
// mode, with an unbuffered ports channel. NAT discovery is not performed here;
// that happens in Start.
func NewTCPNetwork(natType NATType) *TCPNetwork {
    n := new(TCPNetwork)
    n.natType = natType
    n.ports = make(chan string)
    return n
}

// Dialer returns a *net.Dialer that originates connections from addr and
// gives up on a connection attempt after DialerTimeout seconds. The returned
// error is always nil.
func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) {
    d := new(net.Dialer)
    d.Timeout = time.Duration(DialerTimeout) * time.Second
    d.LocalAddr = addr
    // d.KeepAlive = KeepAlivePeriod * time.Minute
    return d, nil
}

// Listener opens a listener on addr.
//
// When the network uses UPNP and Start found a NAT device, the port the
// listener is actually bound to is handed to the background
// updatePortMappings goroutine, which is started on first use. The mapping is
// requested only after the listen succeeded, and it uses the bound port, so a
// requested port of 0 maps the port the OS assigned.
func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) {
    l, err := net.Listen(addr.Network(), addr.String())
    if err != nil {
        return nil, err
    }
    // Without a discovered NAT (Start not called, or UPnP discovery failed)
    // the mapping goroutine would call a method on a nil NAT and panic.
    if self.natType == UPNP && self.nat != nil {
        // Non host:port listener addresses (e.g. unix sockets) cannot be mapped.
        if _, port, perr := net.SplitHostPort(l.Addr().String()); perr == nil {
            if self.quit == nil {
                self.quit = make(chan chan bool)
                go self.updatePortMappings()
            }
            self.ports <- port
        }
    }
    return l, nil
}

// Start prepares NAT traversal according to natType.
//
//   - NONE: nothing to do.
//   - UPNP: runs upnpDiscover with upnpDiscoverAttempts and stores the device
//     in self.nat; discovery errors are returned wrapped in "UPNP failed".
//   - PMP: returns an error because it is not implemented.
//   - any other value: returns an "Invalid NAT type" error.
func (self *TCPNetwork) Start() error {
    switch self.natType {
    case NONE:
        // No NAT traversal requested.
    case UPNP:
        nat, uerr := upnpDiscover(upnpDiscoverAttempts)
        if uerr != nil {
            // The original format string lacked a verb, producing
            // "UPNP failed: %!(EXTRA ...)" instead of the cause.
            return fmt.Errorf("UPNP failed: %v", uerr)
        }
        self.nat = nat
    case PMP:
        return fmt.Errorf("PMP not implemented")
    default:
        return fmt.Errorf("Invalid NAT type: %v", self.natType)
    }
    return nil
}

// Stop asks the background port-mapping goroutine (started by Listener) to
// exit and waits for it to acknowledge the request.
//
// If no goroutine is running (natType is not UPNP, no NAT was found, Listener
// was never called, or Stop already ran), Stop returns immediately instead of
// blocking forever on a nil quit channel.
func (self *TCPNetwork) Stop() {
    if self.quit == nil {
        return
    }
    q := make(chan bool)
    self.quit <- q
    <-q
    // Make repeated Stop calls no-ops and let a later Listener start a fresh
    // goroutine.
    self.quit = nil
}

// addPortMapping asks the NAT device to forward external TCP port lport to the
// same internal port, passing portMappingTimeout as the mapping timeout.
// Success is logged at debug level; failure is logged as an error and returned.
func (self *TCPNetwork) addPortMapping(lport int) (err error) {
    const description = "p2p listen port"
    if _, err = self.nat.AddPortMapping("TCP", lport, lport, description, portMappingTimeout); err != nil {
        logger.Errorf("unable to add port mapping on %v: %v", lport, err)
        return err
    }
    logger.Debugf("succesfully added port mapping on %v", lport)
    return nil
}

// updatePortMappings is the background goroutine started by Listener when
// UPnP is enabled. It owns the set of local ports that should be forwarded by
// the NAT device:
//
//   - each port string received on self.ports is parsed, mapped right away and
//     remembered (even if that first attempt fails, so it is retried later);
//   - every portMappingUpdateInterval seconds all remembered ports are mapped
//     again, so mappings requested with portMappingTimeout are renewed before
//     they expire;
//   - when a reply channel arrives on self.quit the loop ends, all remembered
//     mappings are deleted, and only then is true sent on the reply channel.
func (self *TCPNetwork) updatePortMappings() {
    // A ticker rather than a one-shot timer, so the refresh repeats.
    ticker := time.NewTicker(portMappingUpdateInterval * time.Second)
    defer ticker.Stop()

    lports := make(map[int]struct{}) // set of ports to keep mapped
    var reply chan bool
loop:
    for {
        select {
        case port := <-self.ports:
            // An unsigned parse with bit size 16 accepts exactly 0-65535; a
            // signed 16-bit parse would reject every port above 32767.
            p, err := strconv.ParseUint(port, 10, 16)
            if err != nil {
                logger.Errorf("unable to add port mapping on %q: %v", port, err)
                continue
            }
            lport := int(p)
            if _, ok := lports[lport]; ok {
                continue // already tracked and refreshed on every tick
            }
            lports[lport] = struct{}{}
            // Failures are logged by addPortMapping; the port stays in the set
            // so the next tick retries it.
            _ = self.addPortMapping(lport)
        case <-ticker.C:
            for lport := range lports {
                _ = self.addPortMapping(lport) // failures are logged by addPortMapping
            }
        case reply = <-self.quit:
            break loop
        }
    }

    for lport := range lports {
        if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil {
            logger.Debugf("unable to remove port mapping on %v: %v", lport, err)
        } else {
            logger.Debugf("succesfully removed port mapping on %v", lport)
        }
    }
    // Acknowledge Stop only after cleanup, so the caller knows it has finished.
    reply <- true
}

// NewAddr resolves host with lookupIP (an IP literal is used as-is) and pairs
// the resulting IP with port in a *net.TCPAddr. Resolution errors are returned
// unchanged with a nil address.
func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) {
    ip, err := self.lookupIP(host)
    if err != nil {
        return nil, err
    }
    return &net.TCPAddr{IP: ip, Port: port}, nil
}

// ParseAddr splits a "host:port" string and resolves it with NewAddr.
//
// It returns an error if address is not in host:port form, or if the port is
// not a decimal number in the range 0-65535. Previously such ports were
// silently turned into port 0.
func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) {
    host, port, err := net.SplitHostPort(address)
    if err != nil {
        return nil, err
    }
    iport, err := strconv.Atoi(port)
    if err != nil || iport < 0 || iport > 65535 {
        return nil, fmt.Errorf("invalid port %q in address %q", port, address)
    }
    return self.NewAddr(host, iport)
}

// lookupIP resolves host to a single IP address.
//
// An IP literal is returned as-is. Otherwise host is resolved with
// net.LookupIP; when several addresses come back, one is picked at random to
// spread connections the way round-robin DNS would. Lookup failures and empty
// results are logged with logger.Warnln and returned as errors.
func (*TCPNetwork) lookupIP(host string) (net.IP, error) {
    if ip := net.ParseIP(host); ip != nil {
        return ip, nil
    }

    ips, err := net.LookupIP(host)
    if err != nil {
        logger.Warnln(err)
        return nil, err
    }
    switch len(ips) {
    case 0:
        err = fmt.Errorf("No IP addresses available for %v", host)
        logger.Warnln(err)
        return nil, err
    case 1:
        return ips[0], nil
    }
    // Use a private source instead of rand.Seed. Reseeding the global
    // generator on every lookup resets state shared with the rest of the
    // program, and rand.Seed is deprecated since Go 1.20. The cost is
    // negligible next to the DNS lookup above.
    rng := rand.New(rand.NewSource(time.Now().UnixNano()))
    return ips[rng.Intn(len(ips))], nil
}