# Written by Bram Cohen
# Modified by Cameron Dale
# see LICENSE.txt for license information
#
# $Id: SocketHandler.py 364 2008-01-27 22:38:02Z camrdale-guest $

"""Handle all sockets.

@type logger: C{logging.Logger}
@var logger: the logger to send all log messages to for this module
@type all: C{int}
@var all: all events to check for, both input and output

"""

import socket
from errno import EWOULDBLOCK, ECONNREFUSED, EHOSTUNREACH
try:
    from select import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
    timemult = 1000
except ImportError:
    from selectpoll import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
    timemult = 1
from time import sleep
from clock import clock
import sys, logging
from random import shuffle, randrange
# from BT1.StreamCheck import StreamCheck
# import inspect

# Module-wide logger: every log message from this module is sent here.
logger = logging.getLogger('DebTorrent.SocketHandler')

# Poll event mask covering both input and output readiness; used when a
# socket still has buffered data waiting to be written. (This shadows the
# builtin all(), but it is a documented module variable so it keeps its name.)
all = POLLOUT | POLLIN

class SingleSocket:
    """Manage a single socket.
    
    @type socket_handler: L{SocketHandler}
    @ivar socket_handler: the collection of all sockets
    @type socket: C{socket.socket}
    @ivar socket: the socket to manage (None once the socket is closed)
    @type handler: unknown
    @ivar handler: the handler to use for all communications on the socket
    @type buffer: C{list} of C{string} or C{file}
    @ivar buffer: the queue of data (or open files) waiting to be written
        on the socket
    @type last_hit: C{float}
    @ivar last_hit: the last time a read or write event was handled for
        the socket
    @type fileno: C{int}
    @ivar fileno: the file number of the socket
    @type connected: C{boolean}
    @ivar connected: whether this socket has received an event yet
    @type skipped: C{int}
    @ivar skipped: the number of consecutive writes to the socket that have failed
    @type dns: (C{string}, C{int})
    @ivar dns: the IP address and port to use if one can't be obtained from the socket
    
    """
    
    def __init__(self, socket_handler, sock, handler, dns = None):
        """Wrap a socket so it can be managed by the socket handler.
        
        @type socket_handler: L{SocketHandler}
        @param socket_handler: the collection of all sockets
        @type sock: C{socket.socket}
        @param sock: the socket to manage
        @type handler: unknown
        @param handler: the handler to use for all communications on the socket
        @type dns: (C{string}, C{int})
        @param dns: the IP address and port to use if one can't be obtained
            from the socket (optional, defaults to ('unknown', 0))
        
        """
        
        self.socket_handler = socket_handler
        self.socket = sock
        self.handler = handler
        self.buffer = []
        self.last_hit = clock()
        self.fileno = sock.fileno()
        self.connected = False
        self.skipped = 0
        try:
            self.dns = self.socket.getpeername()
        except socket.error:
            # The peer address is not available yet (e.g. a non-blocking
            # connect is still in progress), so use the one supplied
            if dns is None:
                self.dns = ('unknown', 0)
            else:
                self.dns = dns
        logger.debug('new socket: %r', self.dns)
        
    def get_ip(self, real=False):
        """Get the IP address of the socket.
        
        @type real: C{boolean}
        @param real: whether to try and get the IP address directly from the 
            socket or trust the one supplied when the instance was created 
            (optional, defaults to False)
        @rtype: C{string}
        @return: the IP address
        
        """
        
        return self.getpeername(real)[0]
        
    def getpeername(self, real=False):
        """Get the IP address and port of the socket.
        
        @type real: C{boolean}
        @param real: whether to try and get the IP address directly from the 
            socket or trust the one supplied when the instance was created 
            (optional, defaults to False)
        @rtype: (C{string}, C{int})
        @return: the IP address and port of the remote connection
        
        """
        
        if real:
            try:
                self.dns = self.socket.getpeername()
            except (socket.error, AttributeError):
                # AttributeError: the socket has already been closed (None);
                # in either case keep the previously known address
                pass
        return self.dns
        
    def close(self):
        """Close the socket.
        
        Removes the socket from the handler's collection and poll object,
        closes it, and closes any files still queued for writing (ownership
        of those files was handed over by L{write}).
        
        """
        assert self.socket
        logger.debug('close socket: %r', self.dns)
        self.connected = False
        sock = self.socket
        self.socket = None
        pending = self.buffer
        self.buffer = []
        # Queued files would otherwise never be closed, since try_write()
        # (which closes them at EOF) will not run again for this socket
        for item in pending:
            if hasattr(item, 'read'):
                try:
                    item.close()
                except Exception:
                    pass
        del self.socket_handler.single_sockets[self.fileno]
        self.socket_handler.poll.unregister(sock)
        sock.close()

    def shutdown(self, val):
        """Shutdown the socket.
        
        @type val: C{int}
        @param val: the type of event to shutdown the socket for.
            0 = reading, 1 = writing, 2 = reading and writing.
        
        """
        
        logger.debug('socket %r shutdown:%s', self.dns, val)
        self.socket.shutdown(val)

    def is_flushed(self):
        """Check if the socket is flushed (no data is waiting to be sent).
        
        @rtype: C{boolean}
        @return: whether the socket is flushed
        
        """
        
        return not self.buffer

    def write(self, s):
        """Write data out on the socket.
        
        Adds the data to the buffer of data waiting to be written, then tries 
        to write the waiting data out if nothing else was already queued.
        A queued file is read and sent in chunks, and closed once exhausted
        (or when the socket is closed).
        
        @type s: C{string} or C{file}
        @param s: the data to write, or an already opened file to write out
        
        """
        
        assert self.socket is not None
        self.buffer.append(s)
        if len(self.buffer) == 1:
            self.try_write()

    def try_write(self):
        """Try to write waiting data on the socket.
        
        Sends as much buffered data as the socket accepts, expanding queued
        files into 4096-byte chunks as it goes, and stops at the first short
        or failed send. A socket error other than EWOULDBLOCK, or 3
        consecutive unproductive attempts, marks the socket dead by queueing
        it on the handler's C{dead_from_write} list. Otherwise the socket is
        re-registered for input and output events if data is still waiting,
        or for input events only if the buffer is empty. Nothing is sent
        until the socket has received its first poll event.
        
        """
        if self.connected:
            dead = False
            try:
                while self.buffer:
                    # Replace a queued file at the head of the buffer with its
                    # next chunk of data; exhausted files are closed and dropped
                    while self.buffer and hasattr(self.buffer[0], 'read'):
                        data = self.buffer[0].read(4096)
                        if data:
                            self.buffer.insert(0, data)
                        else:
                            self.buffer[0].close()
                            del self.buffer[0]

                    # Make sure there's still data to send
                    if not self.buffer:
                        break

                    buf = self.buffer[0]
                    amount = self.socket.send(buf)
                    if amount == 0:
                        self.skipped += 1
                        break
                    self.skipped = 0
                    if amount != len(buf):
                        # Partial send: keep the remainder for the next attempt
                        self.buffer[0] = buf[amount:]
                        break
                    del self.buffer[0]
            except socket.error, e:
                # EWOULDBLOCK only means the send buffer is full right now
                try:
                    dead = e[0] != EWOULDBLOCK
                except (IndexError, TypeError):
                    dead = True
                self.skipped += 1
            if self.skipped >= 3:
                dead = True
            if dead:
                logger.debug('Socket is dead from write: %r', self.dns)
                self.socket_handler.dead_from_write.append(self)
                return
        if self.buffer:
            self.socket_handler.poll.register(self.socket, all)
        else:
            self.socket_handler.poll.register(self.socket, POLLIN)

    def set_handler(self, handler):
        """Set the handler to use for this socket.
        
        @type handler: unknown
        @param handler: the handler to use for all communications on the socket
        
        """
        
        self.handler = handler

class SocketHandler:
    """The collection of all open sockets.
    
    @type timeout: C{float}
    @ivar timeout: sockets with no read or write activity for this many
        seconds are closed by L{scan_for_timeouts}
    @type ipv6_enable: C{boolean}
    @ivar ipv6_enable: allow the client to connect to peers via IPv6
    @type readsize: C{int}
    @ivar readsize: the maximum amount of data to read from a socket
    @type poll: C{select.poll}
    @ivar poll: the poll object to use to poll the sockets
    @type single_sockets: C{dictionary} of {C{int}: L{SingleSocket}}
    @ivar single_sockets: the collection of all open sockets, keys are the 
        socket's file number
    @type dead_from_write: C{list} of L{SingleSocket}
    @ivar dead_from_write: the sockets that have failed due to writing
    @type max_connects: C{int}
    @ivar max_connects: the maximum number of sockets to have open at a time
    @type servers: C{dictionary} of {C{int}: C{socket.socket}}
    @ivar servers: the socket listeners, keys are the file numbers
    @type interfaces: C{list} of C{string}
    @ivar interfaces: the interfaces that have been bound to
    @type ports: C{list} of C{int}
    @ivar ports: the ports that are being listened on
    @type handlers: C{dictionary} of {C{int}: unknown}
    @ivar handlers: the handlers that are used for the listened ports, 
        keys are the ports
    @type handler: unknown
    @ivar handler: the default handler for connections (only exists once
        L{set_handler} has been called without a port)
    
    """
    
    def __init__(self, timeout, ipv6_enable, readsize = 100000):
        """Initialize the instance.
        
        @type timeout: C{float}
        @param timeout: seconds of inactivity after which a socket is closed
            by L{scan_for_timeouts}
        @type ipv6_enable: C{boolean}
        @param ipv6_enable: allow the client to connect to peers via IPv6
        @type readsize: C{int}
        @param readsize: the maximum amount of data to read from a socket
            (optional, defaults to 100000)
        
        """
        
        self.timeout = timeout
        self.ipv6_enable = ipv6_enable
        self.readsize = readsize
        self.poll = poll()
        self.single_sockets = {}
        self.dead_from_write = []
        self.max_connects = 1000
        self.servers = {}
        self.interfaces = []
        self.ports = []
        self.handlers = {}

    def scan_for_timeouts(self):
        """Close (and notify the handlers of) all sockets that have timed out."""
        t = clock() - self.timeout
        tokill = [s for s in self.single_sockets.values() if s.last_hit < t]
        for k in tokill:
            # An earlier close in this loop may already have closed it
            if k.socket is not None:
                self._close_socket(k)

    def bind(self, port, bind = '', reuse = False, ipv6_socket_style = 1):
        """Bind to listen on a single port.
        
        @type port: C{int}
        @param port: the port to listen on
        @type bind: C{string}
        @param bind: comma separated list of addresses (IPs or hostnames) to
            bind to (optional, defaults to the default IPv6/IPv4 addresses)
        @type reuse: C{boolean}
        @param reuse: whether to use SO_REUSEADDR to bind (optional, defaults 
            to False). This allows the bind to work if the socket is still
            open in the TIME_WAIT state from a recently shutdown server.
        @type ipv6_socket_style: C{int}
        @param ipv6_socket_style: when binding to the default addresses with
            IPv6 enabled, 0 opens only an IPv6 server socket, any other value
            also opens a separate IPv4 server socket (optional, defaults to 1)
        @raise socket.error: if the port can not be bound; in that case no
            new server sockets are left open
        
        """
        
        port = int(port)
        addrinfos = []
        # Collect new servers separately so a failure leaves the servers from
        # earlier binds untouched (this allows multiple binds)
        newservers = {}
        newinterfaces = []
        # If bind is given, treat it as a comma separated list and bind to all
        # the addresses (can be IPs or hostnames); otherwise bind to the
        # default IPv6 and/or IPv4 address
        if bind:
            if self.ipv6_enable:
                socktype = socket.AF_UNSPEC
            else:
                socktype = socket.AF_INET
            bind = bind.split(',')
            for addr in bind:
                addrinfos.extend(socket.getaddrinfo(addr, port,
                                               socktype, socket.SOCK_STREAM))
        else:
            if self.ipv6_enable:
                addrinfos.append([socket.AF_INET6, None, None, None, ('', port)])
            if not addrinfos or ipv6_socket_style != 0:
                addrinfos.append([socket.AF_INET, None, None, None, ('', port)])
        for addrinfo in addrinfos:
            server = None
            try:
                server = socket.socket(addrinfo[0], socket.SOCK_STREAM)
                if reuse:
                    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                server.setblocking(0)
                server.bind(addrinfo[4])
                newservers[server.fileno()] = server
                if bind:
                    newinterfaces.append(server.getsockname()[0])
                server.listen(64)
                self.poll.register(server, POLLIN)
            except socket.error, e:
                # Undo everything created by this call: the servers already
                # saved (possibly registered with the poll object) and the
                # socket that just failed, if it was never saved
                failed = list(newservers.values())
                if server is not None and server not in failed:
                    failed.append(server)
                for old in failed:
                    try:
                        self.poll.unregister(old)
                    except Exception:
                        pass
                    try:
                        old.close()
                    except Exception:
                        pass
                if self.ipv6_enable and ipv6_socket_style == 0 and newservers:
                    raise socket.error('blocked port (may require ipv6_binds_v4 to be set)')
                raise socket.error(str(e))
        if not newservers:
            raise socket.error('unable to open server port')
        self.ports.append(port)
        # Save the newly created items
        for key,value in newservers.items():
            self.servers[key] = value
        for item in newinterfaces:
            self.interfaces.append(item)
        logger.info('Successfully bound to port '+str(port))

    def find_and_bind(self, minport, maxport, bind = '', reuse = False,
                      ipv6_socket_style = 1, randomizer = False):
        """Bind to listen on a single port within a range.
        
        At most 20 ports are tried when randomizing; otherwise every port in
        the range is tried in order.
        
        @type minport: C{int}
        @param minport: the minimum port to listen on
        @type maxport: C{int}
        @param maxport: the maximum port to listen on
        @type bind: C{string}
        @param bind: the addresses to bind to (optional, defaults to the 
            default IPv6 and IPv4 addresses). Parsed as a comma separated 
            list (can be IP addresses or hostnames).
        @type reuse: C{boolean}
        @param reuse: whether to use SO_REUSEADDR to bind (optional, defaults 
            to False). This allows the bind to work if the socket is still
            open in the TIME_WAIT state from a recently shutdown server.
        @type ipv6_socket_style: C{int}
        @param ipv6_socket_style: passed on to L{bind} (optional, defaults to 1)
        @type randomizer: C{boolean}
        @param randomizer: whether to randomize the range or use it sequentially
        @rtype: C{int}
        @return: the port that was bound to
        @raise socket.error: if none of the ports in the range can be bound
        
        """

        e = 'maxport less than minport - no ports to check'
        if maxport-minport < 50 or not randomizer:
            portrange = range(minport, maxport+1)
            if randomizer:
                shuffle(portrange)
                portrange = portrange[:20]  # check a maximum of 20 ports
        else:
            portrange = []
            while len(portrange) < 20:
                listen_port = randrange(minport, maxport+1)
                if not listen_port in portrange:
                    portrange.append(listen_port)
        for listen_port in portrange:
            try:
                self.bind(listen_port, bind, reuse = reuse,
                          ipv6_socket_style = ipv6_socket_style)
                return listen_port
            except socket.error, e:
                pass
        # Report the last failure (or the empty-range message)
        raise socket.error(str(e))


    def set_handler(self, handler, port = None):
        """Set the handler to use for a port (or the default handler).
        
        @type handler: unknown
        @param handler: the data handler to use to process data on the port
        @type port: C{int}
        @param port: the port to use the handler for
            (optional, defaults to setting the default handler)
        
        """
        
        if port is None:
            self.handler = handler
        else:
            self.handlers[port] = handler


    def start_connection_raw(self, dns, socktype = socket.AF_INET, handler = None):
        """Initiate a new connection to a peer (setting the type of socket).
        
        The connect is non-blocking; the result code of C{connect_ex} is not
        checked here (presumably failures surface later as poll events).
        
        @type dns: (C{string}, C{int})
        @param dns: the IP address and port number to contact the peer on
        @type socktype: C{int}
        @param socktype: the type of socket to open
        @type handler: unknown
        @param handler: the data handler to use to process data on the connection
            (optional, defaults to using the default handler)
        @rtype: L{SocketHandler.SingleSocket}
        @return: the new connection made to the peer
        @raise socket.error: if the connection fails
        
        """
        
        if handler is None:
            handler = self.handler
        sock = socket.socket(socktype, socket.SOCK_STREAM)
        ok = False
        try:
            sock.setblocking(0)
            try:
                sock.connect_ex(dns)
            except socket.error:
                raise
            except Exception, e:
                raise socket.error(str(e))
            ok = True
        finally:
            if not ok:
                # Don't leak the descriptor of a socket that failed to start
                sock.close()
        self.poll.register(sock, POLLIN)
        s = SingleSocket(self, sock, handler, dns)
        self.single_sockets[sock.fileno()] = s
        return s


    def start_connection(self, dns, handler = None, randomize = False):
        """Initiate a new connection to a peer.
        
        Resolves the address and tries each result in turn until one
        connection attempt can be started.
        
        @type dns: (C{string}, C{int})
        @param dns: the IP address and port number to contact the peer on
        @type handler: unknown
        @param handler: the data handler to use to process data on the connection
            (optional, defaults to using the default handler)
        @type randomize: C{boolean}
        @param randomize: whether to randomize the possible sockets or 
            choose one sequentially
        @rtype: L{SocketHandler.SingleSocket}
        @return: the new connection made to the peer
        @raise socket.error: if the connection fails
        
        """
        
        if handler is None:
            handler = self.handler
        if self.ipv6_enable:
            socktype = socket.AF_UNSPEC
        else:
            socktype = socket.AF_INET
        try:
            addrinfos = socket.getaddrinfo(dns[0], int(dns[1]),
                                           socktype, socket.SOCK_STREAM)
        except socket.error:
            raise
        except Exception, e:
            raise socket.error(str(e))
        if randomize:
            shuffle(addrinfos)
        for addrinfo in addrinfos:
            try:
                s = self.start_connection_raw(addrinfo[4],addrinfo[0],handler)
                break
            except Exception:
                # Try the next resolved address
                pass
        else:
            raise socket.error('unable to connect')
        return s


    def _sleep(self):
        """Sleep for one second."""
        sleep(1)
        
    def handle_events(self, events):
        """Handle any events that have occurred on the open sockets.
        
        Events on listening sockets accept new connections (up to
        C{max_connects}); events on connections read incoming data, flush
        buffered output, or close the connection on error/hangup.
        
        @type events: C{list} of (C{int}, C{int})
        @param events: the socket file descriptors and event types that have occurred on them
        
        """
        
        for sock, event in events:
            s = self.servers.get(sock)
            if s:
                if (event & (POLLHUP | POLLERR)) != 0:
                    self.poll.unregister(s)
                    s.close()
                    del self.servers[sock]
                    logger.error("lost server socket")
                elif len(self.single_sockets) < self.max_connects:
                    try:
                        port = s.getsockname()[1]
                        handler = self.handlers.get(port, self.handler)
                        newsock, addr = s.accept()
                        newsock.setblocking(0)
                        nss = SingleSocket(self, newsock, handler)
                        self.single_sockets[newsock.fileno()] = nss
                        self.poll.register(newsock, POLLIN)
                        handler.external_connection_made(nss)
                    except socket.error:
                        self._sleep()
            else:
                s = self.single_sockets.get(sock)
                if not s:
                    continue
                s.connected = True
                if (event & (POLLHUP | POLLERR)):
                    self._close_socket(s)
                    continue
                if (event & POLLIN):
                    try:
                        s.last_hit = clock()
                        data = s.socket.recv(self.readsize)
                        if not data:
                            self._close_socket(s)
                        else:
                            s.handler.data_came_in(s, data)
                    except socket.error, e:
                        # Not every socket.error carries an (errno, msg) pair
                        try:
                            code = e[0]
                        except (IndexError, TypeError):
                            code = None
                        if code != EWOULDBLOCK:
                            self._close_socket(s)
                            continue
                # The read handler may have closed the socket (socket is None)
                if (event & POLLOUT) and s.socket and not s.is_flushed():
                    s.last_hit = clock()
                    s.try_write()
                    if s.is_flushed():
                        s.handler.connection_flushed(s)

    def close_dead(self):
        """Close sockets that have failed to be written."""
        # Closing may cause more sockets to die, so repeat until none remain
        while self.dead_from_write:
            old = self.dead_from_write
            self.dead_from_write = []
            for s in old:
                if s.socket:
                    self._close_socket(s)

    def _close_socket(self, s):
        """Close an open socket and notify its handler.
        
        @type s: L{SingleSocket}
        @param s: the socket to close
        
        """
        
        s.close()
        s.handler.connection_lost(s)

    def do_poll(self, t):
        """Poll the open sockets.
        
        If the poll returns None (which it shouldn't), then 5% of the open 
        sockets will be closed and C{max_connects} lowered accordingly.
        
        @type t: C{float}
        @param t: seconds to wait for an event before timing out
        @rtype: C{list} of (C{int}, C{int})
        @return: the (file descriptor, event) pairs reported by the poll
        
        """
        
        r = self.poll.poll(t*timemult)
        if r is None:
            connects = len(self.single_sockets)
            to_close = int(connects*0.05)+1 # close 5% of sockets
            self.max_connects = connects-to_close
            closelist = list(self.single_sockets.values())
            shuffle(closelist)
            closelist = closelist[:to_close]
            for sock in closelist:
                self._close_socket(sock)
            return []
        return r

    def get_stats(self):
        """Get some information about the bound interfaces and ports.
        
        @rtype: C{dictionary}
        @return: info about the bound interfaces
        
        """
        
        return { 'interfaces': self.interfaces,
                 'port': self.ports }


    def shutdown(self):
        """Close all open sockets and servers, ignoring any errors."""
        # Iterate over copies: closing a socket removes it from the dict
        for ss in list(self.single_sockets.values()):
            try:
                ss.close()
            except:
                pass
        for server in list(self.servers.values()):
            try:
                server.close()
            except:
                pass

