# Written by Bram Cohen
# Modified by Cameron Dale
# see LICENSE.txt for license information
#
# $Id: Encrypter.py 267 2007-08-18 03:12:43Z camrdale-guest $

"""Make encrypted connections to peers.

@type logger: C{logging.Logger}
@var logger: the logger to send all log messages to for this module
@type MAX_INCOMPLETE: C{int}
@var MAX_INCOMPLETE: the maximum number of incomplete connections to have
    outstanding before new connections to initiate get queued
@type option_pattern: C{string}
@var option_pattern: the supported options to send to all peers
@type incompletecounter: L{IncompleteCounter}
@var incompletecounter: the counter to use to track the number of incomplete
    connections outstanding

"""

from cStringIO import StringIO
from binascii import b2a_hex
import struct
from socket import error as socketerror
from DebTorrent.BTcrypto import Crypto
from DebTorrent.__init__ import protocol_name, make_readable
import logging

# All log output from this module goes through this logger.
logger = logging.getLogger('DebTorrent.BT1.Encrypter')

# When True, each connection dumps its traffic to a peerlog.<ip>.txt file.
DEBUG = False

# Threshold consulted by IncompleteCounter.toomany().
MAX_INCOMPLETE = 8

# The 8 option bytes written in every handshake sent by this module (all zero).
option_pattern = '\x00' * 8


class IncompleteCounter:
    """Count the incomplete connections that are currently outstanding.
    
    @type c: C{int}
    @ivar c: how many incomplete connections are currently outstanding
    
    """
    
    def __init__(self):
        """Create a counter that starts at zero."""
        self.c = 0
        
    def increment(self):
        """Add one outstanding incomplete connection."""
        self.c = self.c + 1
        
    def decrement(self):
        """Remove one outstanding incomplete connection."""
        self.c = self.c - 1
        
    def toomany(self):
        """Check whether the incomplete-connection limit has been reached.
        
        @rtype: C{boolean}
        @return: True once the count is at or above MAX_INCOMPLETE
        
        """
        
        return not self.c < MAX_INCOMPLETE
    
# The single counter shared by every Connection in this module.
incompletecounter = IncompleteCounter()


class Connection:
    """A single, possibly encrypted, connection to a peer.
    
    The BitTorrent handshake is in the order: header, options, download id 
    (info hash), peer id, [length, message], ....

    @type Encoder: L{Encoder}
    @ivar Encoder: the collection of all connections
    @type connection: L{DebTorrent.SocketHandler.SingleSocket}
    @ivar connection: the low-level connection to the peer
    @type connecter: L{Connecter.Connecter}
    @ivar connecter: the Connecter instance to use
    @type dns: (C{string}, C{int})
    @ivar dns: the IP address and port to connect to
    @type id: C{string}
    @ivar id: the peer ID of the peer
    @type locally_initiated: C{boolean}
    @ivar locally_initiated: whether the connection was initiated locally
    @type readable_id: C{string}
    @ivar readable_id: the human-readable ID of the peer
    @type complete: C{boolean}
    @ivar complete: whether the handshake is complete
    @type keepalive: C{method}
    @ivar keepalive: the method to call to send a keepalive message on the connection
    @type closed: C{boolean}
    @ivar closed: whether the connection has been closed
    @type buffer: C{string}
    @ivar buffer: the buffer of received data from the connection (becomes a
        C{list} of strings once L{_switch_to_read2} has been called)
    @type bufferlen: C{int}
    @ivar bufferlen: the length of the buffer (None until L{_switch_to_read2}
        has been called)
    @type log: C{file}
    @ivar log: the log file to write to
    @type read: C{method}
    @ivar read: the method to use to read from the connection
    @type write: C{method}
    @ivar write: the method to use to write to the connection
    @type cryptmode: C{int}
    @ivar cryptmode: the type of encryption being used
    @type encrypter: L{DebTorrent.BTcrypto.Crypto}
    @ivar encrypter: the already created Crypto instance, if the connection
            was externally handshaked (optional, defaults to creating a new one)
    @type encrypted: C{boolean}
    @ivar encrypted: whether the connection is encrypted (will be None if
        that is not yet known)
    @type next_len: C{int}
    @ivar next_len: the next amount of data to read from the connection
    @type next_func: C{method}
    @ivar next_func: the next method to use to process incoming data on the 
        connection
    @type options: C{string}
    @ivar options: the options read from the externally handshaked
            connection (optional, defaults to None)
    @type _logwritefunc: C{method}
    @ivar _logwritefunc: the saved write method for intercepting connection
        writes for logging purposes
    @type _max_search: C{int}
    @ivar _max_search: the number of remaining bytes to search for the pattern
    
    @group Logging: _log_start, _log_write
    @group Information: get_ip, get_id, get_readable_id, is_locally_initiated, is_encrypted, is_flushed
    @group Encryption: read_header, _start_crypto, _end_crypto, read_crypto_header, _search_for_pattern, read_encrypted_header
    @group Incoming Encryption: read_crypto_block3a, read_crypto_block3b, 
        read_crypto_block3c, read_crypto_pad3, read_crypto_ia, read_crypto_block3done
    @group Outgoing Encryption: read_crypto_block4a, read_crypto_block4b, read_crypto_pad4, read_crypto_block4done
    @group BitTorrent Handshake: _read_header, read_options, read_download_id, read_peer_id
    @group BitTorrent Messages: read_len, read_message, read_dead
    @group Data: send_message_raw, _write, data_came_in, _write_buffer, _read, _switch_to_read2, _read2
    @group Connection: _auto_close, close, sever, connection_flushed, connection_lost

    """
    
    def __init__(self, Encoder, connection, dns, id,
                 ext_handshake=False, encrypted = None, options = None):
        """Initialize the instance and start handling the connection.
        
        @type Encoder: L{Encoder}
        @param Encoder: the collection of all connections
        @type connection: L{DebTorrent.SocketHandler.SingleSocket}
        @param connection: the low-level connection to the peer
        @type dns: (C{string}, C{int})
        @param dns: the IP address and port to connect to
        @type id: C{string}
        @param id: the peer ID of the peer to connect to (will be None if 
            the connection was not initiated locally)
        @type ext_handshake: C{boolean}
        @param ext_handshake: whether the connection has already been
            handshaked by another module (optional, defaults to False)
        @type encrypted: C{int} or C{DebTorrent.BT1Crypto.Crypto}
        @param encrypted: the type of encryption the connection supports
            (0 for none), or the already created Crypto instance, if the connection
            was externally handshaked (optional, defaults to creating a new one)
        @type options: C{string}
        @param options: the options read from the externally handshaked
            connection (optional, defaults to None)
        
        """
        
        self.Encoder = Encoder
        self.connection = connection
        self.connecter = Encoder.connecter
        self.dns = dns
        self.id = id
        # Only outgoing connections know the peer ID up front.
        self.locally_initiated = (id is not None)
        self.readable_id = make_readable(id)
        self.complete = False
        self.keepalive = lambda: None
        self.closed = False
        self.buffer = ''
        self.bufferlen = None
        self.log = None
        self.read = self._read
        self.write = self._write
        self.cryptmode = 0
        self.encrypter = None
        if self.locally_initiated:
            # Outgoing: we speak first, so send our opening bytes now.
            incompletecounter.increment()
            if encrypted:
                logger.info('Initiating an encrypted connection to '+self.readable_id)
                self.encrypted = True
                self.encrypter = Crypto(True)
                self.write(self.encrypter.pubkey+self.encrypter.padding())
            else:
                logger.info('Initiating an unencrypted connection to '+self.readable_id)
                self.encrypted = False
                self.write(chr(len(protocol_name)) + protocol_name + 
                    option_pattern + self.Encoder.download_id )
            self.next_len, self.next_func = 1+len(protocol_name), self.read_header
        elif ext_handshake:
            # Another module already consumed the start of the handshake.
            self.Encoder.connecter.external_connection_made += 1
            if encrypted:   # passed an already running encrypter
                logger.info('Received an external encrypted connection from '+self.readable_id)
                self.encrypter = encrypted
                self.encrypted = True
                self._start_crypto()
                self.next_len, self.next_func = 14, self.read_crypto_block3c
            else:
                logger.info('Received an external unencrypted connection from '+self.readable_id)
                self.encrypted = False
                self.options = options
                self.write(chr(len(protocol_name)) + protocol_name + 
                           option_pattern + self.Encoder.download_id + self.Encoder.my_id)
                self.next_len, self.next_func = 20, self.read_peer_id
        else:
            logger.info('Received an unknown connection from '+self.readable_id)
            self.encrypted = None       # don't know yet
            self.next_len, self.next_func = 1+len(protocol_name), self.read_header
        # Drop connections whose handshake has not completed within 30 seconds.
        self.Encoder.raw_server.add_task(self._auto_close, 30)


    def _log_start(self):
        """Start logging detailed connection data.
        
        Only called with DEBUG = True (i.e. never). Adds an intercept function
        so that all connection writes are first written to the log file, then
        out on the connection.
        
        """
        
        self.log = open('peerlog.'+self.get_ip()+'.txt','a')
        self.log.write('connected - ')
        if self.locally_initiated:
            self.log.write('outgoing\n')
        else:
            self.log.write('incoming\n')
        self._logwritefunc = self.write
        self.write = self._log_write

    def _log_write(self, s):
        """Write to the log file, then the connection.
        
        @type s: C{string}
        @param s: the data to write
        
        """
        
        self.log.write('w:'+b2a_hex(s)+'\n')
        self._logwritefunc(s)
        

    def get_ip(self, real=False):
        """Get the IP address of the connection.
        
        @type real: C{boolean}
        @param real: whether to double check that it's the real IP address
        @rtype: C{string}
        @return: the IP address
        
        """
        
        return self.connection.get_ip(real)

    def get_id(self):
        """Get the peer ID.
        
        @rtype: C{string}
        @return: the peer ID
        
        """
        
        return self.id

    def get_readable_id(self):
        """Get the human-readable peer ID.
        
        @rtype: C{string}
        @return: the peer ID
        
        """
        
        return self.readable_id

    def is_locally_initiated(self):
        """Determine whether the connection was locally initiated.
        
        @rtype: C{boolean}
        @return: whether the connection was locally initiated
        
        """
        
        return self.locally_initiated

    def is_encrypted(self):
        """Determine whether the connection is encrypted.
        
        @rtype: C{boolean}
        @return: whether the connection is encrypted
        
        """
        
        return bool(self.encrypted)

    def is_flushed(self):
        """Determine whether the connection is flushed.
        
        @rtype: C{boolean}
        @return: whether the connection is flushed
        
        """
        
        return self.connection.is_flushed()

    def _read_header(self, s):
        """Read the protocol header.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        if s == chr(len(protocol_name))+protocol_name:
            return 8, self.read_options
        return None

    def read_header(self, s):
        """Read the possibly encrypted protocol header.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        if self._read_header(s):
            if self.encrypted or self.Encoder.config['crypto_stealth']:
                logger.info('Dropped the connection as it was unencrypted: '+self.readable_id)
                return None
            return 8, self.read_options
        if self.locally_initiated and not self.encrypted:
            logger.info('Got a bad protocol name: '+self.readable_id)
            return None
        elif not self.Encoder.config['crypto_allowed']:
            logger.info('Dropped the connection as it was encrypted: '+self.readable_id)
            return None
        if not self.encrypted:
            self.encrypted = True
            self.encrypter = Crypto(self.locally_initiated)
        # Not a plaintext header: these bytes start the peer's public key.
        self._write_buffer(s)
        return self.encrypter.keylength, self.read_crypto_header

    ################## ENCRYPTION SUPPORT ######################

    def _start_crypto(self):
        """Setup the connection for encrypted communication."""
        self.encrypter.setrawaccess(self._read,self._write)
        self.write = self.encrypter.write
        self.read = self.encrypter.read
        # Anything already buffered arrived on the encrypted stream.
        if self.buffer:
            self.buffer = self.encrypter.decrypt(self.buffer)

    def _end_crypto(self):
        """Return the connection back to unencrypted communication."""
        self.read = self._read
        self.write = self._write
        self.encrypter = None

    def read_crypto_header(self, s):
        """Read the encryption key.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        self.encrypter.received_key(s)
        self.encrypter.set_skey(self.Encoder.download_id)
        if self.locally_initiated:
            if self.Encoder.config['crypto_only']:
                cryptmode = '\x00\x00\x00\x02'    # full stream encryption
            else:
                cryptmode = '\x00\x00\x00\x03'    # header or full stream
            padc = self.encrypter.padding()
            self.write( self.encrypter.block3a
                      + self.encrypter.block3b
                      + self.encrypter.encrypt(
                            ('\x00'*8)            # VC
                          + cryptmode             # acceptable crypto modes
                          + struct.pack('>h', len(padc))
                          + padc                  # PadC
                          + '\x00\x00' ) )        # no initial payload data
            self._max_search = 520
            return 1, self.read_crypto_block4a
        self.write(self.encrypter.pubkey+self.encrypter.padding())
        self._max_search = 520
        return 0, self.read_crypto_block3a

    def _search_for_pattern(self, s, pat):
        """Search for a pattern in the encrypted protocol header.
        
        Closes the connection once more than L{_max_search} bytes have been
        scanned without finding the pattern.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @type pat: C{string}
        @param pat: the pattern to find
        @rtype: C{boolean}
        @return: whether the pattern was found
        
        """
        
        p = s.find(pat)
        if p < 0:
            if len(s) >= len(pat):
                self._max_search -= len(s)+1-len(pat)
            if self._max_search < 0:
                self.close()
                return False
            # Keep the tail, in case the pattern straddles the next read.
            self._write_buffer(s[1-len(pat):])
            return False
        self._write_buffer(s[p+len(pat):])
        return True

    ### INCOMING CONNECTION ###

    def read_crypto_block3a(self, s):
        """Find the block3a crypto information in the connection.
        
        @type s: C{string}
        @param s: the data read from the connection
        @rtype: C{int}, C{method}
        @return: the next length to read and method to call with the data
        
        """
        
        if not self._search_for_pattern(s,self.encrypter.block3a):
            return -1, self.read_crypto_block3a     # wait for more data
        return len(self.encrypter.block3b), self.read_crypto_block3b

    def read_crypto_block3b(self, s):
        """Process the block3b crypto information in the connection.
        
        @type s: C{string}
        @param s: the data read from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        if s != self.encrypter.block3b:
            logger.info('Dropped the encrypted connection due to a bad block3b: '+self.readable_id)
            return None
        self.Encoder.connecter.external_connection_made += 1
        self._start_crypto()
        return 14, self.read_crypto_block3c

    def read_crypto_block3c(self, s):
        """Read the encrypted protocol mode.
        
        @type s: C{string}
        @param s: the data read from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        if s[:8] != ('\x00'*8):             # check VC
            logger.info('Dropped the encrypted connection due to a bad VC: '+self.readable_id)
            return None
        self.cryptmode = struct.unpack('>i', s[8:12])[0] % 4
        if self.cryptmode == 0:
            logger.info('Dropped the encrypted connection due to a bad crypt mode: '+self.readable_id)
            return None                     # no encryption selected
        if ( self.cryptmode == 1            # only header encryption
             and self.Encoder.config['crypto_only'] ):
            logger.info('Dropped the header-only encrypted connection: '+self.readable_id)
            return None
        padlen = (ord(s[12])<<8)+ord(s[13])
        if padlen > 512:
            logger.info('Dropped the encrypted connection due to bad padding: '+self.readable_id)
            return None
        # Read the padding plus the 2-byte initial-payload length after it.
        return padlen+2, self.read_crypto_pad3

    def read_crypto_pad3(self, s):
        """Read the encrypted protocol padding.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        # Discard the padding; only the trailing 2-byte IA length matters.
        s = s[-2:]
        ialen = (ord(s[0])<<8)+ord(s[1])
        # NOTE: a 2-byte value can never exceed 65535; this is a defensive guard.
        if ialen > 65535:
            logger.info('Dropped the encrypted connection due to very long padding: '+self.readable_id)
            return None
        if self.cryptmode == 1:
            cryptmode = '\x00\x00\x00\x01'    # header only encryption
        else:
            cryptmode = '\x00\x00\x00\x02'    # full stream encryption
        padd = self.encrypter.padding()
        self.write( ('\x00'*8)            # VC
                  + cryptmode             # encryption mode
                  + struct.pack('>h', len(padd))
                  + padd )                # PadD
        if ialen:
            return ialen, self.read_crypto_ia
        return self.read_crypto_block3done()

    def read_crypto_ia(self, s):
        """Read the initial payload data from the connection.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        if DEBUG:
            self._log_start()
            self.log.write('r:'+b2a_hex(s)+'(ia)\n')
            if self.buffer:
                self.log.write('r:'+b2a_hex(self.buffer)+'(buffer)\n')
        return self.read_crypto_block3done(s)

    def read_crypto_block3done(self, ia=''):
        """Finish with the encrypted header.
        
        @type ia: C{string}
        @param ia: the initial payload data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        if DEBUG:
            if not self.log:
                self._log_start()
        if self.cryptmode == 1:     # only handshake encryption
            assert not self.buffer  # oops; check for exceptions to this
            self._end_crypto()
        if ia:
            # The initial payload is processed before any later buffered data.
            self._write_buffer(ia)
        return 1+len(protocol_name), self.read_encrypted_header

    ### OUTGOING CONNECTION ###

    def read_crypto_block4a(self, s):
        """Read the encrypted protocol header.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        if not self._search_for_pattern(s,self.encrypter.VC_pattern()):
            return -1, self.read_crypto_block4a     # wait for more data
        self._start_crypto()
        return 6, self.read_crypto_block4b

    def read_crypto_block4b(self, s):
        """Read the encrypted protocol mode and padding.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        self.cryptmode = struct.unpack('>i',s[:4])[0] % 4
        if self.cryptmode == 1:             # only header encryption
            if self.Encoder.config['crypto_only']:
                logger.info('Dropped the header-only encrypted connection: '+self.readable_id)
                return None
        elif self.cryptmode != 2:
            logger.info('Dropped the encrypted connection due to an unknown crypt mode: '+self.readable_id)
            return None                     # unknown encryption
        padlen = (ord(s[4])<<8)+ord(s[5])
        if padlen > 512:
            logger.info('Dropped the encrypted connection due to bad padding: '+self.readable_id)
            return None
        if padlen:
            return padlen, self.read_crypto_pad4
        return self.read_crypto_block4done()

    def read_crypto_pad4(self, s):
        """Read the encrypted protocol padding.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        # discard data
        return self.read_crypto_block4done()

    def read_crypto_block4done(self):
        """Finish with the encrypted header.
        
        For header-only encryption (cryptmode 1) the rest of the stream is
        plaintext. Any data already buffered at this point has been run
        through the decrypter by L{_start_crypto}, so it cannot be trusted.
        The connection is dropped in that case, mirroring the empty-buffer
        assertion in L{read_crypto_block3done}.
        
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        if DEBUG:
            self._log_start()
        if self.cryptmode == 1:     # only handshake encryption
            if self.buffer:  # oops; check for exceptions to this
                logger.info('Dropped the header-only encrypted connection due to unexpected buffered data: '+self.readable_id)
                return None
            self._end_crypto()
        self.write(chr(len(protocol_name)) + protocol_name + 
            option_pattern + self.Encoder.download_id)
        return 1+len(protocol_name), self.read_encrypted_header

    ### START PROTOCOL OVER ENCRYPTED CONNECTION ###

    def read_encrypted_header(self, s):
        """Read the regular protocol name header from the encrypted stream.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        return self._read_header(s)

    ################################################

    def read_options(self, s):
        """Read the options from the header.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        self.options = s
        return 20, self.read_download_id

    def read_download_id(self, s):
        """Verify the torrent infohash from the header.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        if ( s != self.Encoder.download_id
             or not self.Encoder.check_ip(ip=self.get_ip()) ):
            logger.warning('IPs or torrent info hashes do not match: '+self.readable_id)
            return None
        if not self.locally_initiated:
            if not self.encrypted:
                self.Encoder.connecter.external_connection_made += 1
            # Incoming: reply with our own handshake only once the infohash is verified.
            self.write(chr(len(protocol_name)) + protocol_name + 
                option_pattern + self.Encoder.download_id + self.Encoder.my_id)
        return 20, self.read_peer_id

    def read_peer_id(self, s):
        """Read/verify the peer's ID.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        if not self.encrypted and self.Encoder.config['crypto_only']:
            logger.info('Dropped the unencrypted connection: '+self.readable_id)
            return None     # allows older trackers to ping,
                            # but won't proceed w/ connections
        if not self.id:
            self.id = s
            self.readable_id = make_readable(s)
        else:
            if s != self.id:
                logger.info('Peer ID does not match: '+self.readable_id)
                return None
        self.complete = self.Encoder.got_id(self)
        if not self.complete:
            # The peer-supplied ID is passed as an argument, never as part of
            # the format string, so a '%' in it cannot break formatting.
            logger.warning('Connection to %r disallowed for security: %s',
                           self.dns, self.readable_id)
            return None
        if self.locally_initiated:
            self.write(self.Encoder.my_id)
            incompletecounter.decrement()
        self._switch_to_read2()
        logger.info('Handshake complete: '+self.readable_id)
        c = self.Encoder.connecter.connection_made(self)
        self.keepalive = c.send_keepalive
        return 4, self.read_len

    def read_len(self, s):
        """Read the length of the message.
        
        The 4-byte length prefix is decoded as an unsigned big-endian integer.
        A signed decode would let a peer send a negative length. That value
        would pass the maximum-length check and corrupt the buffering in
        L{_read2}.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        l = struct.unpack('>I', s)[0]
        if l > self.Encoder.max_len:
            logger.warning('Dropped the connection due to bad length message: '+self.readable_id)
            return None
        return l, self.read_message

    def read_message(self, s):
        """Read the message.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        # A zero-length message is passed here as '' and is not forwarded to the connecter.
        if s != '':
            self.connecter.got_message(self, s)
        return 4, self.read_len

    def read_dead(self, s):
        """Return None to close the connection.
        
        @type s: C{string}
        @param s: the incoming data from the connection (not used)
        @rtype: None
        @return: None
        
        """
        
        return None

    def _auto_close(self):
        """Close the connection if the handshake is not yet complete."""
        if not self.complete and not self.closed:
            logger.warning('Connection to %r dropped due to handshake taking too long: %s',
                           self.dns, self.readable_id)
            self.close()

    def close(self):
        """Close the connection."""
        if not self.closed:
            self.connection.close()
            self.sever()

    def sever(self):
        """Clean up the connection for closing."""
        if self.log:
            self.log.write('closed\n')
            self.log.close()
        self.closed = True
        del self.Encoder.connections[self.connection]
        if self.complete:
            self.connecter.connection_lost(self)
        elif self.locally_initiated:
            # Still counted as incomplete; release its slot.
            incompletecounter.decrement()

    def send_message_raw(self, message):
        """Write a message out on the connection.
        
        @type message: C{string}
        @param message: the data to write to the connection
        
        """
        
        self.write(message)

    def _write(self, message):
        """Write a raw message out on the connection.
        
        @type message: C{string}
        @param message: the raw data to write to the connection
        
        """
        
        if not self.closed:
            self.connection.write(message)

    def data_came_in(self, connection, s):
        """Process the incoming data on the connection.
        
        @type connection: L{DebTorrent.SocketHandler.SingleSocket}
        @param connection: the connection the data came in on (not used)
        @type s: C{string}
        @param s: the incoming data from the connection
        
        """
        
        self.read(s)

    def _write_buffer(self, s):
        """Write data back onto the buffer.
        
        @type s: C{string}
        @param s: the data to rebuffer
        
        """
        
        self.buffer = s+self.buffer

    def _read(self, s):
        """Process the data that comes in.
        
        @type s: C{string}
        @param s: the (unencrypted) incoming data from the connection
        
        """
        
        if self.log:
            self.log.write('r:'+b2a_hex(s)+'\n')
        self.Encoder.measurefunc(len(s))
        self.buffer += s
        while True:
            if self.closed:
                return
            # self.next_len = # of characters function expects
            # or 0 = all characters in the buffer
            # or -1 = wait for next read, then all characters in the buffer
            # not compatible w/ keepalives, switch out after all negotiation complete
            if self.next_len <= 0:
                m = self.buffer
                self.buffer = ''
            elif len(self.buffer) >= self.next_len:
                m = self.buffer[:self.next_len]
                self.buffer = self.buffer[self.next_len:]
            else:
                return
            try:
                x = self.next_func(m)
            except:
                logger.exception('Dropped connection due to exception: '+self.readable_id)
                self.next_len, self.next_func = 1, self.read_dead
                raise
            if x is None:
                self.close()
                return
            self.next_len, self.next_func = x
            if self.next_len < 0:  # already checked buffer
                return             # wait for additional data
            if self.bufferlen is not None:
                # The handshake just finished; the rest is handled by _read2.
                self._read2('')
                return

    def _switch_to_read2(self):
        """Switch from _read to the more efficient _read2 method."""
        # Rebuffering is only supported by _read; clearing the method makes
        # any later call fail immediately.
        self._write_buffer = None
        if self.encrypter:
            self.encrypter.setrawaccess(self._read2,self._write)
        else:
            self.read = self._read2
        self.bufferlen = len(self.buffer)
        self.buffer = [self.buffer]

    def _read2(self, s):
        """Efficiently process the data that comes in.
        
        More efficient, buffers the incoming data by appending it to a list
        rather than creating a new string by adding it to the end. Requires
        buffer to be a list of strings, and bufferlen to be its total length.
        
        @type s: C{string}
        @param s: the (unencrypted) incoming data from the connection
        
        """
        
        if self.log:
            self.log.write('r:'+b2a_hex(s)+'\n')
        self.Encoder.measurefunc(len(s))
        while True:
            if self.closed:
                return
            # p: how many bytes of s are still needed to complete next_len.
            p = self.next_len-self.bufferlen
            if self.next_len == 0:
                m = ''
            elif s:
                if p > len(s):
                    self.buffer.append(s)
                    self.bufferlen += len(s)
                    return
                self.bufferlen = len(s)-p
                self.buffer.append(s[:p])
                m = ''.join(self.buffer)
                if p == len(s):
                    self.buffer = []
                else:
                    self.buffer=[s[p:]]
                s = ''
            elif p <= 0:
                # assert len(self.buffer) == 1
                s = self.buffer[0]
                self.bufferlen = len(s)-self.next_len
                m = s[:self.next_len]
                if p == 0:
                    self.buffer = []
                else:
                    self.buffer = [s[self.next_len:]]
                s = ''
            else:
                return
            try:
                x = self.next_func(m)
            except:
                logger.exception('Dropped connection due to exception: '+self.readable_id)
                self.next_len, self.next_func = 1, self.read_dead
                raise
            if x is None:
                self.close()
                return
            self.next_len, self.next_func = x
            if self.next_len < 0:  # already checked buffer
                return             # wait for additional data
            

    def connection_flushed(self, connection):
        """Flush the connection.
        
        @type connection: L{DebTorrent.SocketHandler.SingleSocket}
        @param connection: the connection that was flushed (not used)
        
        """
        
        if self.complete:
            self.connecter.connection_flushed(self)

    def connection_lost(self, connection):
        """Sever the connection.
        
        @type connection: L{DebTorrent.SocketHandler.SingleSocket}
        @param connection: the connection that was lost
        
        """
        
        if connection in self.Encoder.connections:
            self.sever()


class _dummy_banlist:
    """A dummy list of banned peers."""
    
    def includes(self, x):
        """Check if a peer is banned (always returns False).
        
        @type x: C{string}
        @param x: the IP address of the peer to check
        @rtype: C{boolean}
        @return: whether the peer is banned
        
        """
        
        return False

class Encoder:
    """The collection of all (possibly encrypted) connections.
    
    @type raw_server: L{DebTorrent.RawServer.RawServer}
    @ivar raw_server: the server instance to use
    @type connecter: L{Connecter.Connecter}
    @ivar connecter: the Connecter instance to use
    @type my_id: C{string}
    @ivar my_id: the peer ID to use
    @type max_len: C{int}
    @ivar max_len: the maximum length message to accept
    @type schedulefunc: C{method}
    @ivar schedulefunc: method to call to schedule future function invocation
    @type keepalive_delay: C{int}
    @ivar keepalive_delay: the delay between sending keepalive messages
    @type download_id: C{string}
    @ivar download_id: the infohash of the torrent being downloaded
    @type measurefunc: C{method}
    @ivar measurefunc: method to call with the size of incoming data
    @type config: C{dictionary}
    @ivar config: the configuration parameters
    @type connections: C{dictionary}
    @ivar connections: keys are the L{DebTorrent.SocketHandler.SingleSocket} 
        connections, values are the corresponding L{Connection} instances
    @type banned: C{dictionary}
    @ivar banned: keys are IP addresses that are banned
    @type external_bans: C{class}
    @ivar external_bans: the instance to check for banned peers in
    @type to_connect: C{list} of ((C{string}, C{int}), C{string}, C{boolean})
    @ivar to_connect: the list of IP address, port, peer ID, and whether to encrypt
    @type paused: C{boolean}
    @ivar paused: whether the download is paused
    @type max_connections: C{int}
    @ivar max_connections: the maximum number of connections to accept
    
    """
    
    def __init__(self, connecter, raw_server, my_id, max_len,
            schedulefunc, keepalive_delay, download_id, 
            measurefunc, config, bans=_dummy_banlist() ):
        """Initialize the instance.
        
        @type connecter: L{Connecter.Connecter}
        @param connecter: the Connecter instance to use
        @type raw_server: L{DebTorrent.RawServer.RawServer}
        @param raw_server: the server instance to use
        @type my_id: C{string}
        @param my_id: the peer ID to use
        @type max_len: C{int}
        @param max_len: the maximum length message to accept
        @type schedulefunc: C{method}
        @param schedulefunc: method to call to schedule future function invocation
        @type keepalive_delay: C{int}
        @param keepalive_delay: the delay between sending keepalive messages
        @type download_id: C{string}
        @param download_id: the infohash of the torrent being downloaded
        @type measurefunc: C{method}
        @param measurefunc: method to call with the size of incoming data
        @type config: C{dictionary}
        @param config: the configuration parameters
        @type bans: C{class}
        @param bans: the instance to check for banned peers in
            (optional, defaults to an instance of L{_dummy_banlist})
        
        """
        
        self.raw_server = raw_server
        self.connecter = connecter
        self.my_id = my_id
        self.max_len = max_len
        self.schedulefunc = schedulefunc
        self.keepalive_delay = keepalive_delay
        self.download_id = download_id
        self.measurefunc = measurefunc
        self.config = config
        self.connections = {}
        self.banned = {}
        self.external_bans = bans
        self.to_connect = []
        self.paused = False
        # A configured maximum of 0 means "no limit"
        if self.config['max_connections'] == 0:
            self.max_connections = 2 ** 30
        else:
            self.max_connections = self.config['max_connections']
        schedulefunc(self.send_keepalives, keepalive_delay)

    def send_keepalives(self):
        """Periodically send keepalive messages on all the connections.
        
        Reschedules itself first, so keepalives resume after a pause ends.
        
        """
        self.schedulefunc(self.send_keepalives, self.keepalive_delay)
        if self.paused:
            return
        logger.debug('Sending keepalive messages to all connected peers')
        for c in self.connections.values():
            c.keepalive()

    def start_connections(self, list):
        """Replace the queue of connections to start with a new list.
        
        A queue-processing task is only scheduled when the queue was empty,
        since otherwise one is already pending and will pick up the new list.
        
        @type list: C{list} of ((C{string}, C{int}), C{string}, C{boolean})
        @param list: the list of IP address, port, peer ID, and whether to encrypt
        
        """
        
        if not self.to_connect:
            self.raw_server.add_task(self._start_connection_from_queue)
        self.to_connect = list

    def _start_connection_from_queue(self):
        """Start the next queued connection, rescheduling while any remain.
        
        Waits 60 seconds when at the connection limits, 1 second when paused
        or when too many connections are still incomplete, and otherwise
        starts one connection and immediately reschedules.
        
        """
        # The queue may have been replaced by an empty list since this task
        # was scheduled; popping from it would raise IndexError.
        if not self.to_connect:
            return
        if self.connecter.external_connection_made:
            max_initiate = self.config['max_initiate']
        else:
            # Allow more outgoing connections when no peer has connected in
            max_initiate = int(self.config['max_initiate']*1.5)
        cons = len(self.connections)
        if cons >= self.max_connections or cons >= max_initiate:
            delay = 60
        elif self.paused or incompletecounter.toomany():
            delay = 1
        else:
            delay = 0
            dns, id, encrypted = self.to_connect.pop(0)
            self.start_connection(dns, id, encrypted)
        if self.to_connect:
            self.raw_server.add_task(self._start_connection_from_queue, delay)

    def start_connection(self, dns, id, encrypted = None):
        """Start a connection to a peer.
        
        The connection is skipped (returning True) when paused, at the
        connection limit, the peer is ourselves, the IP is banned, an
        unencrypted connection is requested while crypto_only is set, or
        the peer is already connected.
        
        @type dns: (C{string}, C{int})
        @param dns: the IP address and port to connect to
        @type id: C{string}
        @param id: the peer ID of the peer
        @type encrypted: C{boolean}
        @param encrypted: whether to encrypt the connection
            (optional, defaults to None; forced to True when the
            crypto_only option is set)
        @rtype: C{boolean}
        @return: False if a socket error occurs, True otherwise
       
        """
        
        if self.paused:
            logger.info('Not connecting due to being paused')
            return True
        if len(self.connections) >= self.max_connections:
            logger.info('Not connecting due to too many connections: '+str(len(self.connections))+' >= '+str(self.max_connections))
            return True
        if id == self.my_id:
            logger.info('Not connecting due to it being my ID: '+dns[0])
            return True
        if not self.check_ip(ip=dns[0]):
            logger.info('Not connecting due to the IP being banned: '+dns[0])
            return True
        if self.config['crypto_only']:
            if encrypted is None or encrypted:  # fails on encrypted = 0
                encrypted = True
            else:
                logger.info('Not connecting due to not being encrypted: '+dns[0])
                return True
        for v in self.connections.values():
            if v is None:
                continue
            if id and v.id == id:
                logger.info('Not connecting due to a matching peer ID: '+id)
                return True
            ip = v.get_ip(True)
            if self.config['security'] and ip != 'unknown' and ip == dns[0]:
                logger.info('Not connecting due to a matching IP: '+ip)
                return True
            if dns == v.dns:
                logger.info('Not connecting due to already being connected: %r', dns)
                # Previously missing: without this a duplicate connection
                # was opened despite the log message.
                return True
        try:
            logger.debug('initiating connection to: '+str(dns)+', '+str(id)+', '+str(encrypted))
            c = self.raw_server.start_connection(dns)
            con = Connection(self, c, dns, id, encrypted = encrypted)
            self.connections[c] = con
            c.set_handler(con)
        except socketerror:
            return False
        return True

    def _start_connection(self, dns, id, encrypted = None):
        """Schedule the start of a connection to a peer.
        
        @type dns: (C{string}, C{int})
        @param dns: the IP address and port to connect to
        @type id: C{string}
        @param id: the peer ID of the peer
        @type encrypted: C{boolean}
        @param encrypted: whether to encrypt the connection
            (optional, defaults to None)
       
        """
        
        # Default arguments bind the current values into the callback
        def foo(self=self, dns=dns, id=id, encrypted=encrypted):
            self.start_connection(dns, id, encrypted)
        self.schedulefunc(foo, 0)

    def check_ip(self, connection=None, ip=None):
        """Check whether the connection to the IP is allowed.
        
        @type connection: L{DebTorrent.SocketHandler.SingleSocket}
        @param connection: the connection whose IP should be checked
            (optional, but one of connection/ip must be specified)
        @type ip: C{string}
        @param ip: the IP address of the connection
            (optional, but one of connection/ip must be specified)
        @rtype: C{boolean}
        @return: whether the connection is allowed
        
        """
        
        if not ip:
            ip = connection.get_ip(True)
        # Internal bans are only enforced when the security option is set
        if self.config['security'] and ip in self.banned:
            return False
        if self.external_bans.includes(ip):
            return False
        return True

    def got_id(self, connection):
        """Check whether the connection to the peer ID is allowed.
        
        Checks whether the peer ID is ours, or is already connected on another
        connection. If the same peer ID is connected from the same IP, the old
        connection is closed in favour of the new one. If the security config
        option is set, old connections from the same IP are also closed.
        
        @type connection: L{Connection}
        @param connection: the connection whose peer ID should be checked
        @rtype: C{boolean}
        @return: whether the connection is allowed
        
        """
        
        if connection.id == self.my_id:
            self.connecter.external_connection_made -= 1
            logger.debug('New connection matches our peer ID')
            return False
        ip = connection.get_ip(True)
        for v in self.connections.values():
            if connection is not v:
                if connection.id == v.id:
                    if ip == v.get_ip(True):
                        logger.debug('Closing an old connection to the same peer')
                        v.close()
                    else:
                        logger.debug('New peer ID is already connected')
                        return False
                if self.config['security'] and ip != 'unknown' and ip == v.get_ip(True):
                    logger.debug('Closing an old connection to the same IP')
                    v.close()
        return True

    def external_connection_made(self, connection):
        """Process an externally made connection.
        
        @type connection: L{DebTorrent.SocketHandler.SingleSocket}
        @param connection: the connection that was made
        @rtype: C{boolean}
        @return: whether the connection is accepted
        
        """
        
        if self.paused or len(self.connections) >= self.max_connections:
            if self.paused:
                logger.info('Not allowing connection due to being paused')
            else:
                logger.info('Not allowing connection due to too many connections: '+
                            str(len(self.connections))+' >= '+str(self.max_connections))
            connection.close()
            return False
        dns = connection.getpeername()
        logger.info("Received a connection from: %r", dns)
        con = Connection(self, connection, dns, None)
        self.connections[connection] = con
        connection.set_handler(con)
        return True

    def externally_handshaked_connection_made(self, connection, options,
                                              already_read, encrypted = None):
        """Process an externally handshaked connection.
        
        @type connection: L{DebTorrent.SocketHandler.SingleSocket}
        @param connection: the connection that was made
        @type options: C{string}
        @param options: the options read from the externally handshaked
            connection
        @type already_read: C{string}
        @param already_read: the data that has already been received on the connection
        @type encrypted: C{DebTorrent.BT1Crypto.Crypto}
        @param encrypted: the already created Crypto instance, if the connection
            was externally handshaked (optional, defaults to creating a new one)
        @rtype: C{boolean}
        @return: whether the connection is accepted
        
        """
        
        if self.paused:
            logger.info('Not allowing external connection due to being paused')
            connection.close()
            return False
        if len(self.connections) >= self.max_connections:
            logger.info('Not allowing external connection due to too many connections: '+
                        str(len(self.connections))+' >= '+str(self.max_connections))
            connection.close()
            return False
        # Look up the peer address before the ban check, whose log message
        # needs it (previously this raised UnboundLocalError for banned IPs).
        dns = connection.getpeername()
        if not self.check_ip(connection=connection):
            logger.info('Not allowing external connection due to the IP being banned: '+dns[0])
            connection.close()
            return False
        logger.info("Received an externally handled connection from: %r", dns)
        con = Connection(self, connection, dns, None,
                ext_handshake = True, encrypted = encrypted, options = options)
        self.connections[connection] = con
        connection.set_handler(con)
        if already_read:
            con.data_came_in(con, already_read)
        return True

    def close_all(self):
        """Close all the currently open connections."""
        for c in self.connections.values():
            c.close()
        self.connections = {}

    def ban(self, ip):
        """Ban an IP address from ever connecting again.
        
        @type ip: C{string}
        @param ip: the IP address to ban
        
        """
        
        logger.info('Banned: '+ip)
        self.banned[ip] = 1

    def pause(self, flag):
        """Set the paused flag.
        
        When paused, no new connections are made and no keepalives are sent.
        
        @type flag: C{boolean}
        @param flag: whether the download is paused
        
        """
        
        self.paused = flag
