#!/usr/bin/env python
#
# types_extras.py - tests for extras types conversions
#
# Copyright (C) 2008-2010 Federico Di Gregorio  <fog@debian.org>
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
# License for more details.

import re
import sys
import six
from decimal import Decimal
from datetime import date, datetime
from functools import wraps

from psycopg2cffi.tests.psycopg2_tests.testutils import unittest, \
        skip_if_no_uuid, skip_before_postgres, _u, ConnectingTestCase, \
        decorate_all_tests

import psycopg2cffi as psycopg2
from psycopg2cffi import extras, extensions


def filter_scs(conn, s):
    """Adapt an expected quoted literal to the connection's string syntax.

    Unless the server reports ``standard_conforming_strings`` as ``'off'``,
    every ``E'`` escape-string prefix in the bytes ``s`` is rewritten to a
    plain ``'``; with the setting ``'off'`` ``s`` is returned untouched.
    """
    scs = conn.get_parameter_status("standard_conforming_strings")
    if scs != 'off':
        return s.replace(b"E'", b"'")
    return s

class TypesExtrasTests(ConnectingTestCase):
    """Test that all type conversions are working."""

    def execute(self, *args):
        """Execute a query and return the first column of its first row."""
        curs = self.conn.cursor()
        curs.execute(*args)
        return curs.fetchone()[0]

    @skip_if_no_uuid
    def testUUID(self):
        """A uuid.UUID round-trips, and NULL::uuid comes back as None."""
        import uuid
        extras.register_uuid()
        u = uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e350')
        s = self.execute("SELECT %s AS foo", (u,))
        self.assertEqual(u, s)
        # must survive NULL cast to a uuid
        s = self.execute("SELECT NULL::uuid AS foo")
        self.assertTrue(s is None, repr(s))

    @skip_if_no_uuid
    def testUUIDARRAY(self):
        """Lists of UUIDs (with or without NULLs) round-trip as uuid[]."""
        import uuid
        extras.register_uuid()
        u = [uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e350'),
             uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e352')]
        s = self.execute("SELECT %s AS foo", (u,))
        self.assertEqual(u, s)
        # array with a NULL element
        u = [uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e350'), None]
        s = self.execute("SELECT %s AS foo", (u,))
        self.assertEqual(u, s)
        # must survive NULL cast to a uuid[]
        s = self.execute("SELECT NULL::uuid[] AS foo")
        self.assertTrue(s is None, repr(s))
        # what about empty arrays?
        s = self.execute("SELECT '{}'::uuid[] AS foo")
        self.assertTrue(isinstance(s, list) and len(s) == 0, repr(s))

    def testINET(self):
        """An Inet value round-trips, and NULL::inet comes back as None."""
        extras.register_inet()
        i = extras.Inet("192.168.1.0/24")
        s = self.execute("SELECT %s AS foo", (i,))
        self.assertEqual(i.addr, s.addr)
        # must survive NULL cast to inet
        s = self.execute("SELECT NULL::inet AS foo")
        self.assertTrue(s is None, repr(s))

    def testINETARRAY(self):
        """A list of Inet values round-trips as inet[]."""
        extras.register_inet()
        i = extras.Inet("192.168.1.0/24")
        s = self.execute("SELECT %s AS foo", ([i],))
        self.assertEqual(i.addr, s[0].addr)
        # must survive NULL cast to inet[]
        s = self.execute("SELECT NULL::inet[] AS foo")
        self.assertTrue(s is None, repr(s))

    def test_inet_conform(self):
        """Inet adapts to a quoted ``::inet`` literal from str and unicode."""
        from psycopg2cffi.extras import Inet
        i = Inet("192.168.1.0/24")
        a = extensions.adapt(i)
        a.prepare(self.conn)
        self.assertEqual(
            filter_scs(self.conn, b"E'192.168.1.0/24'::inet"),
            a.getquoted())

        # adapts ok with unicode too
        i = Inet(_u(b"192.168.1.0/24"))
        a = extensions.adapt(i)
        a.prepare(self.conn)
        self.assertEqual(
            filter_scs(self.conn, b"E'192.168.1.0/24'::inet"),
            a.getquoted())

    def test_adapt_fail(self):
        """Adapting an unknown type raises ProgrammingError naming the type."""
        class Foo(object): pass
        try:
            extensions.adapt(Foo(), extensions.ISQLQuote, None)
        except psycopg2.ProgrammingError as err:
            self.assertEqual(str(err), "can't adapt type 'Foo'")
        else:
            # the message check above must not be skipped silently
            self.fail("ProgrammingError not raised")


def skip_if_no_hstore(f):
    """Decorate a test method so it is skipped when hstore is missing.

    The wrapper asks ``HstoreAdapter.get_oids`` for the hstore oids on
    ``self.conn``; the wrapped test only runs when a result is returned
    and its first element is non-empty, otherwise the test is skipped.
    """
    @wraps(f)
    def _wrapper(self):
        from psycopg2cffi.extras import HstoreAdapter
        found = HstoreAdapter.get_oids(self.conn)
        if found is not None and found[0]:
            return f(self)
        return self.skipTest("hstore not available in test database")

    return _wrapper

class HstoreTestCase(ConnectingTestCase):
    """Tests for hstore adaptation, parsing and typecaster registration."""

    def test_adapt_8(self):
        """dict adaptation using the pre-9.0 ``(k => v) || ...`` syntax."""
        if self.conn.server_version >= 90000:
            return self.skipTest("skipping dict adaptation with PG pre-9 syntax")

        from psycopg2cffi.extras import HstoreAdapter

        o = {'a': '1', 'b': "'", 'c': None}
        if self.conn.encoding == 'UTF8':
            o['d'] = _u(b'\xc3\xa0')

        a = HstoreAdapter(o)
        a.prepare(self.conn)
        q = a.getquoted()

        self.assert_(q.startswith(b"(("), q)
        ii = q[1:-1].split(b"||")
        ii.sort()

        self.assertEqual(len(ii), len(o))
        self.assertEqual(ii[0], filter_scs(self.conn, b"(E'a' => E'1')"))
        self.assertEqual(ii[1], filter_scs(self.conn, b"(E'b' => E'''')"))
        self.assertEqual(ii[2], filter_scs(self.conn, b"(E'c' => NULL)"))
        if 'd' in o:
            encc = _u(b'\xc3\xa0').encode(extensions.encodings[self.conn.encoding])
            self.assertEqual(ii[3],
                    filter_scs(self.conn, b"(E'd' => E'" + encc + b"')"))

    def test_adapt_9(self):
        """dict adaptation using the 9.0+ ``hstore(ARRAY[..], ARRAY[..])`` syntax."""
        if self.conn.server_version < 90000:
            return self.skipTest("skipping dict adaptation with PG 9 syntax")

        from psycopg2cffi.extras import HstoreAdapter

        o = {'a': '1', 'b': "'", 'c': None}
        if self.conn.encoding == 'UTF8':
            o['d'] = _u(b'\xc3\xa0')

        a = HstoreAdapter(o)
        a.prepare(self.conn)
        q = a.getquoted()

        m = re.match(br'hstore\(ARRAY\[([^\]]+)\], ARRAY\[([^\]]+)\]\)', q)
        self.assert_(m, repr(q))

        kk = m.group(1).split(b", ")
        vv = m.group(2).split(b", ")
        ii = list(zip(kk, vv))
        ii.sort()

        def f(*args):
            return tuple([filter_scs(self.conn, s) for s in args])

        self.assertEqual(len(ii), len(o))
        self.assertEqual(ii[0], f(b"E'a'", b"E'1'"))
        self.assertEqual(ii[1], f(b"E'b'", b"E''''"))
        self.assertEqual(ii[2], f(b"E'c'", b"NULL"))
        if 'd' in o:
            encc = _u(b'\xc3\xa0').encode(extensions.encodings[self.conn.encoding])
            self.assertEqual(ii[3], f(b"E'd'", b"E'" + encc + b"'"))

    def test_parse(self):
        """HstoreAdapter.parse accepts valid hstore text and rejects garbage."""
        from psycopg2cffi.extras import HstoreAdapter

        def ok(s, d):
            self.assertEqual(HstoreAdapter.parse(s, None), d)

        ok(None, None)
        ok('', {})
        ok('"a"=>"1", "b"=>"2"', {'a': '1', 'b': '2'})
        ok('"a"  => "1" ,"b"  =>  "2"', {'a': '1', 'b': '2'})
        ok('"a"=>NULL, "b"=>"2"', {'a': None, 'b': '2'})
        ok(r'"a"=>"\"", "\""=>"2"', {'a': '"', '"': '2'})
        ok('"a"=>"\'", "\'"=>"2"', {'a': "'", "'": '2'})
        ok('"a"=>"1", "b"=>NULL', {'a': '1', 'b': None})
        ok(r'"a\\"=>"1"', {'a\\': '1'})
        ok(r'"a\""=>"1"', {'a"': '1'})
        ok(r'"a\\\""=>"1"', {r'a\"': '1'})
        ok(r'"a\\\\\""=>"1"', {r'a\\"': '1'})

        def ko(s):
            self.assertRaises(psycopg2.InterfaceError,
                HstoreAdapter.parse, s, None)

        ko('a')
        ko('"a"')
        ko(r'"a\\""=>"1"')
        ko(r'"a\\\\""=>"1"')
        ko('"a=>"1"')
        ko('"a"=>"1", "b"=>NUL')

    @skip_if_no_hstore
    def test_register_conn(self):
        """register_hstore on a connection casts hstore values to dicts."""
        from psycopg2cffi.extras import register_hstore

        register_hstore(self.conn)
        cur = self.conn.cursor()
        cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
        t = cur.fetchone()
        self.assert_(t[0] is None)
        self.assertEqual(t[1], {})
        self.assertEqual(t[2], {'a': 'b'})

    @skip_if_no_hstore
    def test_register_curs(self):
        """register_hstore on a cursor casts hstore values to dicts."""
        from psycopg2cffi.extras import register_hstore

        cur = self.conn.cursor()
        register_hstore(cur)
        cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
        t = cur.fetchone()
        self.assert_(t[0] is None)
        self.assertEqual(t[1], {})
        self.assertEqual(t[2], {'a': 'b'})

    @skip_if_no_hstore
    def test_register_unicode(self):
        """With unicode=True keys and values come back as text, not bytes."""
        from psycopg2cffi.extras import register_hstore

        register_hstore(self.conn, unicode=True)
        cur = self.conn.cursor()
        cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
        t = cur.fetchone()
        self.assert_(t[0] is None)
        self.assertEqual(t[1], {})
        self.assertEqual(t[2], {_u(b'a'): _u(b'b')})
        self.assert_(isinstance(list(t[2].keys())[0], six.text_type))
        self.assert_(isinstance(list(t[2].values())[0], six.text_type))

    @skip_if_no_hstore
    def test_register_globally(self):
        """A globally registered hstore caster applies to other connections."""
        from psycopg2cffi.extras import register_hstore, HstoreAdapter

        oids = HstoreAdapter.get_oids(self.conn)
        try:
            register_hstore(self.conn, globally=True)
            conn2 = self.connect()
            try:
                # query through the *other* connection: that is what proves
                # the registration is global rather than per-connection
                cur2 = conn2.cursor()
                cur2.execute("select 'a => b'::hstore")
                r = cur2.fetchone()
                self.assert_(isinstance(r[0], dict))
            finally:
                conn2.close()
        finally:
            # register_hstore() registers a caster for every hstore oid and,
            # when known, for the hstore[] oids too: drop all of them so no
            # global caster leaks into other tests
            for oid in tuple(oids[0]) + tuple(oids[1]):
                extensions.string_types.pop(oid, None)

        # verify the caster is not around anymore
        cur = self.conn.cursor()
        cur.execute("select 'a => b'::hstore")
        r = cur.fetchone()
        self.assert_(isinstance(r[0], str))

    @skip_if_no_hstore
    def test_roundtrip(self):
        """dicts of str survive an adapt/cast round trip."""
        from psycopg2cffi.extras import register_hstore
        register_hstore(self.conn)
        cur = self.conn.cursor()

        def ok(d):
            cur.execute("select %s", (d,))
            d1 = cur.fetchone()[0]
            self.assertEqual(len(d), len(d1))
            for k in d:
                self.assert_(k in d1, k)
                self.assertEqual(d[k], d1[k])

        ok({})
        ok({'a': 'b', 'c': None})

        # materialize: on Python 3 map() is a one-shot iterator, which would
        # make zip(ab, ab) pair neighbours and leave ''.join(ab) empty
        ab = list(map(chr, range(32, 128)))
        ok(dict(zip(ab, ab)))
        ok({''.join(ab): ''.join(ab)})

        self.conn.set_client_encoding('latin1')
        if sys.version_info[0] < 3:
            ab = map(chr, range(32, 127) + range(160, 255))
        else:
            ab = bytes(list(range(32, 127)) + list(range(160, 255))).decode('latin1')

        ok({''.join(ab): ''.join(ab)})
        ok(dict(zip(ab, ab)))

    @skip_if_no_hstore
    def test_roundtrip_unicode(self):
        """dicts of text survive a round trip when unicode=True."""
        from psycopg2cffi.extras import register_hstore
        register_hstore(self.conn, unicode=True)
        cur = self.conn.cursor()

        def ok(d):
            cur.execute("select %s", (d,))
            d1 = cur.fetchone()[0]
            self.assertEqual(len(d), len(d1))
            for k, v in d1.items():
                self.assert_(k in d, k)
                self.assertEqual(d[k], v)
                self.assert_(isinstance(k, six.text_type))
                self.assert_(v is None or isinstance(v, six.text_type))

        ok({})
        ok({'a': 'b', 'c': None, 'd': _u(b'\xe2\x82\xac'), _u(b'\xe2\x98\x83'): 'e'})

        # a list, not a map() iterator: it is consumed twice below
        ab = list(map(six.unichr, range(1, 1024)))
        ok({_u(b'').join(ab): _u(b'').join(ab)})
        ok(dict(zip(ab, ab)))

    @skip_if_no_hstore
    def test_oid(self):
        """register_hstore with an explicit oid needs no connection."""
        cur = self.conn.cursor()
        cur.execute("select 'hstore'::regtype::oid")
        oid = cur.fetchone()[0]

        # Note: None as conn_or_cursor is just for testing: not public
        # interface and it may break in future.
        from psycopg2cffi.extras import register_hstore
        register_hstore(None, globally=True, oid=oid)
        try:
            cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
            t = cur.fetchone()
            self.assert_(t[0] is None)
            self.assertEqual(t[1], {})
            self.assertEqual(t[2], {'a': 'b'})

        finally:
            extensions.string_types.pop(oid)

    @skip_if_no_hstore
    @skip_before_postgres(8, 3)
    def test_roundtrip_array(self):
        """A list of dicts round-trips as hstore[]."""
        from psycopg2cffi.extras import register_hstore
        register_hstore(self.conn)

        ds = []
        ds.append({})
        ds.append({'a': 'b', 'c': None})

        # a list, not a map() iterator: it is consumed more than once below
        ab = list(map(chr, range(32, 128)))
        ds.append(dict(zip(ab, ab)))
        ds.append({''.join(ab): ''.join(ab)})

        self.conn.set_client_encoding('latin1')
        if sys.version_info[0] < 3:
            ab = map(chr, range(32, 127) + range(160, 255))
        else:
            ab = bytes(list(range(32, 127)) + list(range(160, 255))).decode('latin1')

        ds.append({''.join(ab): ''.join(ab)})
        ds.append(dict(zip(ab, ab)))

        cur = self.conn.cursor()
        cur.execute("select %s", (ds,))
        ds1 = cur.fetchone()[0]
        self.assertEqual(ds, ds1)

    @skip_if_no_hstore
    @skip_before_postgres(8, 3)
    def test_array_cast(self):
        """A server-side hstore[] is cast to a list of dicts."""
        from psycopg2cffi.extras import register_hstore
        register_hstore(self.conn)
        cur = self.conn.cursor()
        cur.execute("select array['a=>1'::hstore, 'b=>2'::hstore];")
        a = cur.fetchone()[0]
        self.assertEqual(a, [{'a': '1'}, {'b': '2'}])

    @skip_if_no_hstore
    def test_array_cast_oid(self):
        """Explicit oid and array_oid register both casters."""
        cur = self.conn.cursor()
        cur.execute("select 'hstore'::regtype::oid, 'hstore[]'::regtype::oid")
        oid, aoid = cur.fetchone()

        from psycopg2cffi.extras import register_hstore
        register_hstore(None, globally=True, oid=oid, array_oid=aoid)
        try:
            cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore, '{a=>b}'::hstore[]")
            t = cur.fetchone()
            self.assert_(t[0] is None)
            self.assertEqual(t[1], {})
            self.assertEqual(t[2], {'a': 'b'})
            self.assertEqual(t[3], [{'a': 'b'}])

        finally:
            extensions.string_types.pop(oid)
            extensions.string_types.pop(aoid)

    @skip_if_no_hstore
    def test_non_dbapi_connection(self):
        """Registration works with a connection yielding non-tuple rows."""
        from psycopg2cffi.extras import RealDictConnection
        from psycopg2cffi.extras import register_hstore

        conn = self.connect(connection_factory=RealDictConnection)
        try:
            register_hstore(conn)
            curs = conn.cursor()
            curs.execute("select ''::hstore as x")
            self.assertEqual(curs.fetchone()['x'], {})
        finally:
            conn.close()

        conn = self.connect(connection_factory=RealDictConnection)
        try:
            curs = conn.cursor()
            register_hstore(curs)
            curs.execute("select ''::hstore as x")
            self.assertEqual(curs.fetchone()['x'], {})
        finally:
            conn.close()


def skip_if_no_composite(f):
    """Decorate a test method so it is skipped on servers older than 8.0.

    ``self.conn.server_version`` is compared with 80000; any lower version
    is reported via ``skipTest`` as lacking composite type support.
    """
    @wraps(f)
    def _wrapper(self):
        version = self.conn.server_version
        if version >= 80000:
            return f(self)
        return self.skipTest(
            "server version %s doesn't support composite types" % version)

    return _wrapper

class AdaptTypeTestCase(ConnectingTestCase):
    """Tests for record/composite adaptation and composite type casters."""

    @skip_if_no_composite
    def test_none_in_record(self):
        """None inside a tuple is adapted to NULL within the record."""
        curs = self.conn.cursor()
        s = curs.mogrify("SELECT %s;", [(42, None)])
        self.assertEqual(b"SELECT (42, NULL);", s)
        curs.execute("SELECT %s;", [(42, None)])
        d = curs.fetchone()[0]
        self.assertEqual("(42,)", d)

    def test_none_fast_path(self):
        """mogrify renders None as NULL without calling the None adapter."""
        # the None adapter is not actually invoked in regular adaptation
        ext = extensions

        class WonkyAdapter(object):
            def __init__(self, obj): pass
            def getquoted(self): return "NOPE!"

        curs = self.conn.cursor()

        orig_adapter = ext.adapters[type(None), ext.ISQLQuote]
        try:
            ext.register_adapter(type(None), WonkyAdapter)
            self.assertEqual(ext.adapt(None).getquoted(), "NOPE!")

            s = curs.mogrify("SELECT %s;", (None,))
            self.assertEqual(b"SELECT NULL;", s)

        finally:
            ext.register_adapter(type(None), orig_adapter)

    def test_tokenization(self):
        """CompositeCaster.tokenize splits record text into fields."""
        from psycopg2cffi.extras import CompositeCaster
        def ok(s, v):
            self.assertEqual(CompositeCaster.tokenize(s), v)

        ok("(,)", [None, None])
        ok('(,"")', [None, ''])
        ok('(hello,,10.234,2010-11-11)', ['hello', None, '10.234', '2010-11-11'])
        ok('(10,"""")', ['10', '"'])
        ok('(10,",")', ['10', ','])
        ok(r'(10,"\\")', ['10', '\\'])
        ok(r'''(10,"\\',""")''', ['10', '''\\',"'''])
        ok('(10,"(20,""(30,40)"")")', ['10', '(20,"(30,40)")'])
        ok('(10,"(20,""(30,""""(40,50)"""")"")")', ['10', '(20,"(30,""(40,50)"")")'])
        ok('(,"(,""(a\nb\tc)"")")', [None, '(,"(a\nb\tc)")'])
        bytelist = [chr(i) for i in range(1, 128)]
        ok('(\x01,\x02,\x03,\x04,\x05,\x06,\x07,\x08,"\t","\n","\x0b",'
           '"\x0c","\r",\x0e,\x0f,\x10,\x11,\x12,\x13,\x14,\x15,\x16,'
           '\x17,\x18,\x19,\x1a,\x1b,\x1c,\x1d,\x1e,\x1f," ",!,"""",#,'
           '$,%,&,\',"(",")",*,+,",",-,.,/,0,1,2,3,4,5,6,7,8,9,:,;,<,=,>,?,'
           '@,A,B,C,D,E,F,G,H,I,J,K,L,M,N,O,P,Q,R,S,T,U,V,W,X,Y,Z,[,"\\\\",],'
           '^,_,`,a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z,{,|,},'
           '~,\x7f)',
           bytelist)
        ok('(,"\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f'
           '\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !'
           '""#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\\\]'
           '^_`abcdefghijklmnopqrstuvwxyz{|}~\x7f")',
           [None, ''.join(bytelist)])

    def test_register_unicode(self):
        """A UNICODE caster for oid 705 decodes text and bytes queries."""
        from psycopg2cffi._impl import typecasts
        cur = self.conn.cursor()
        extensions.register_type(
            extensions.new_type((705,), "UNKNOWN", extensions.UNICODE))
        try:
            expected = [(_u(b'\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e'),)]
            cur.execute(_u(b"SELECT '\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e' AS japanese;"))
            self.assertEqual(cur.fetchall(), expected)
            cur.execute(b"SELECT '\xe6\x97\xa5\xe6\x9c\xac\xe8\xaa\x9e' AS japanese;")
            self.assertEqual(cur.fetchall(), expected)
        finally:
            # the caster registered above is global: restore the default
            # even if an assertion fails, so later tests are unaffected
            extensions.register_type(typecasts.UNKNOWN)

    @skip_if_no_composite
    def test_cast_composite(self):
        """register_composite exposes metadata and casts to its type."""
        oid = self._create_type("type_isd",
            [('anint', 'integer'), ('astring', 'text'), ('adate', 'date')])

        t = extras.register_composite("type_isd", self.conn)
        self.assertEqual(t.name, 'type_isd')
        self.assertEqual(t.schema, 'public')
        self.assertEqual(t.oid, oid)
        self.assert_(issubclass(t.type, tuple))
        self.assertEqual(t.attnames, ['anint', 'astring', 'adate'])
        self.assertEqual(t.atttypes, [23,25,1082])

        curs = self.conn.cursor()
        r = (10, 'hello', date(2011,1,2))
        curs.execute("select %s::type_isd;", (r,))
        v = curs.fetchone()[0]
        self.assert_(isinstance(v, t.type))
        self.assertEqual(v[0], 10)
        self.assertEqual(v[1], "hello")
        self.assertEqual(v[2], date(2011,1,2))

        try:
            from collections import namedtuple
        except ImportError:
            pass
        else:
            self.assert_(t.type is not tuple)
            self.assertEqual(v.anint, 10)
            self.assertEqual(v.astring, "hello")
            self.assertEqual(v.adate, date(2011,1,2))

    @skip_if_no_composite
    def test_empty_string(self):
        """Empty strings and NULLs are kept distinct in composites."""
        # issue #141
        self._create_type("type_ss", [('s1', 'text'), ('s2', 'text')])
        curs = self.conn.cursor()
        extras.register_composite("type_ss", curs)

        def ok(t):
            curs.execute("select %s::type_ss", (t,))
            rv = curs.fetchone()[0]
            self.assertEqual(t, rv)

        ok(('a', 'b'))
        ok(('a', ''))
        ok(('', 'b'))
        ok(('a', None))
        ok((None, 'b'))
        ok(('', ''))
        ok((None, None))

    @skip_if_no_composite
    def test_cast_nested(self):
        """Composites nested in composites are cast recursively."""
        self._create_type("type_is",
            [("anint", "integer"), ("astring", "text")])
        self._create_type("type_r_dt",
            [("adate", "date"), ("apair", "type_is")])
        self._create_type("type_r_ft",
            [("afloat", "float8"), ("anotherpair", "type_r_dt")])

        extras.register_composite("type_is", self.conn)
        extras.register_composite("type_r_dt", self.conn)
        extras.register_composite("type_r_ft", self.conn)

        curs = self.conn.cursor()
        r = (0.25, (date(2011,1,2), (42, "hello")))
        curs.execute("select %s::type_r_ft;", (r,))
        v = curs.fetchone()[0]

        self.assertEqual(r, v)

        try:
            from collections import namedtuple
        except ImportError:
            pass
        else:
            self.assertEqual(v.anotherpair.apair.astring, "hello")

    @skip_if_no_composite
    def test_register_on_cursor(self):
        """Cursor registration does not affect sibling cursors."""
        self._create_type("type_ii", [("a", "integer"), ("b", "integer")])

        curs1 = self.conn.cursor()
        curs2 = self.conn.cursor()
        extras.register_composite("type_ii", curs1)
        curs1.execute("select (1,2)::type_ii")
        self.assertEqual(curs1.fetchone()[0], (1,2))
        curs2.execute("select (1,2)::type_ii")
        self.assertEqual(curs2.fetchone()[0], "(1,2)")

    @skip_if_no_composite
    def test_register_on_connection(self):
        """Connection registration does not affect other connections."""
        self._create_type("type_ii", [("a", "integer"), ("b", "integer")])

        conn1 = self.connect()
        conn2 = self.connect()
        try:
            extras.register_composite("type_ii", conn1)
            curs1 = conn1.cursor()
            curs2 = conn2.cursor()
            curs1.execute("select (1,2)::type_ii")
            self.assertEqual(curs1.fetchone()[0], (1,2))
            curs2.execute("select (1,2)::type_ii")
            self.assertEqual(curs2.fetchone()[0], "(1,2)")
        finally:
            conn1.close()
            conn2.close()

    @skip_if_no_composite
    def test_register_globally(self):
        """Global registration applies to every connection."""
        self._create_type("type_ii", [("a", "integer"), ("b", "integer")])

        conn1 = self.connect()
        conn2 = self.connect()
        try:
            t = extras.register_composite("type_ii", conn1, globally=True)
            try:
                curs1 = conn1.cursor()
                curs2 = conn2.cursor()
                curs1.execute("select (1,2)::type_ii")
                self.assertEqual(curs1.fetchone()[0], (1,2))
                curs2.execute("select (1,2)::type_ii")
                self.assertEqual(curs2.fetchone()[0], (1,2))
            finally:
                # drop the registered typecasters to help the refcounting
                # script to return precise values.
                del extensions.string_types[t.typecaster.values[0]]
                if t.array_typecaster:
                    del extensions.string_types[
                        t.array_typecaster.values[0]]

        finally:
            conn1.close()
            conn2.close()

    @skip_if_no_composite
    def test_composite_namespace(self):
        """Schema-qualified composite names are resolved and cast."""
        curs = self.conn.cursor()
        curs.execute("""
            select nspname from pg_namespace
            where nspname = 'typens';
            """)
        if not curs.fetchone():
            curs.execute("create schema typens;")
            self.conn.commit()

        self._create_type("typens.typens_ii",
            [("a", "integer"), ("b", "integer")])
        t = extras.register_composite(
            "typens.typens_ii", self.conn)
        self.assertEqual(t.schema, 'typens')
        curs.execute("select (4,8)::typens.typens_ii")
        self.assertEqual(curs.fetchone()[0], (4,8))

    @skip_if_no_composite
    @skip_before_postgres(8, 4)
    def test_composite_array(self):
        """Arrays of composites are cast to lists of composite instances."""
        self._create_type("type_isd",
            [('anint', 'integer'), ('astring', 'text'), ('adate', 'date')])

        t = extras.register_composite("type_isd", self.conn)

        curs = self.conn.cursor()
        r1 = (10, 'hello', date(2011,1,2))
        r2 = (20, 'world', date(2011,1,3))
        curs.execute("select %s::type_isd[];", ([r1, r2],))
        v = curs.fetchone()[0]
        self.assertEqual(len(v), 2)
        self.assert_(isinstance(v[0], t.type))
        self.assertEqual(v[0][0], 10)
        self.assertEqual(v[0][1], "hello")
        self.assertEqual(v[0][2], date(2011,1,2))
        self.assert_(isinstance(v[1], t.type))
        self.assertEqual(v[1][0], 20)
        self.assertEqual(v[1][1], "world")
        self.assertEqual(v[1][2], date(2011,1,3))

    @skip_if_no_composite
    def test_wrong_schema(self):
        """A caster expecting more fields than the record has raises DataError."""
        oid = self._create_type("type_ii", [("a", "integer"), ("b", "integer")])
        from psycopg2cffi.extras import CompositeCaster
        c = CompositeCaster('type_ii', oid, [('a', 23), ('b', 23), ('c', 23)])
        curs = self.conn.cursor()
        extensions.register_type(c.typecaster, curs)
        curs.execute("select (1,2)::type_ii")
        self.assertRaises(psycopg2.DataError, curs.fetchone)

    @skip_if_no_composite
    @skip_before_postgres(8, 4)
    def test_from_tables(self):
        """Table row types (with a dropped column) can be registered and cast."""
        curs = self.conn.cursor()
        curs.execute("""create table ctest1 (
            id integer primary key,
            temp int,
            label varchar
        );""")

        curs.execute("""alter table ctest1 drop temp;""")

        curs.execute("""create table ctest2 (
            id serial primary key,
            label varchar,
            test_id integer references ctest1(id)
        );""")

        curs.execute("""insert into ctest1 (id, label) values
                (1, 'test1'),
                (2, 'test2');""")
        curs.execute("""insert into ctest2 (label, test_id) values
                ('testa', 1),
                ('testb', 1),
                ('testc', 2),
                ('testd', 2);""")

        extras.register_composite("ctest1", curs)
        extras.register_composite("ctest2", curs)

        # explicit ORDER BY: GROUP BY alone does not guarantee row order
        # (recent servers may hash-aggregate records)
        curs.execute("""
            select ctest1, array_agg(ctest2) as test2s
            from (
                select ctest1, ctest2
                from ctest1 inner join ctest2 on ctest1.id = ctest2.test_id
                order by ctest1.id, ctest2.label
            ) x group by ctest1
            order by ctest1;""")

        r = curs.fetchone()
        self.assertEqual(r[0], (1, 'test1'))
        self.assertEqual(r[1], [(1, 'testa', 1), (2, 'testb', 1)])
        r = curs.fetchone()
        self.assertEqual(r[0], (2, 'test2'))
        self.assertEqual(r[1], [(3, 'testc', 2), (4, 'testd', 2)])

    @skip_if_no_composite
    def test_non_dbapi_connection(self):
        """Registration works with a connection yielding non-tuple rows."""
        from psycopg2cffi.extras import RealDictConnection
        from psycopg2cffi.extras import register_composite
        self._create_type("type_ii", [("a", "integer"), ("b", "integer")])

        conn = self.connect(connection_factory=RealDictConnection)
        try:
            register_composite('type_ii', conn)
            curs = conn.cursor()
            curs.execute("select '(1,2)'::type_ii as x")
            self.assertEqual(curs.fetchone()['x'], (1,2))
        finally:
            conn.close()

        conn = self.connect(connection_factory=RealDictConnection)
        try:
            # register on the cursor this time, so both registration paths
            # are exercised with a non-dbapi connection
            curs = conn.cursor()
            register_composite('type_ii', curs)
            curs.execute("select '(1,2)'::type_ii as x")
            self.assertEqual(curs.fetchone()['x'], (1,2))
        finally:
            conn.close()

    @skip_if_no_composite
    def test_subclass(self):
        """A CompositeCaster subclass can override make() to build dicts."""
        oid = self._create_type("type_isd",
            [('anint', 'integer'), ('astring', 'text'), ('adate', 'date')])

        from psycopg2cffi.extras import register_composite, CompositeCaster

        class DictComposite(CompositeCaster):
            def make(self, values):
                return dict(zip(self.attnames, values))

        t = register_composite('type_isd', self.conn, factory=DictComposite)

        self.assertEqual(t.name, 'type_isd')
        self.assertEqual(t.oid, oid)

        curs = self.conn.cursor()
        r = (10, 'hello', date(2011,1,2))
        curs.execute("select %s::type_isd;", (r,))
        v = curs.fetchone()[0]
        self.assert_(isinstance(v, dict))
        self.assertEqual(v['anint'], 10)
        self.assertEqual(v['astring'], "hello")
        self.assertEqual(v['adate'], date(2011,1,2))

    def _create_type(self, name, fields):
        """(Re)create composite type ``name`` with ``(attname, type)`` fields.

        Any existing type of that name is dropped first (cascade). The new
        type is committed and its oid returned. ``name`` may be
        schema-qualified; otherwise the 'public' schema is assumed for the
        oid lookup.
        """
        curs = self.conn.cursor()
        try:
            curs.execute("drop type %s cascade;" % name)
        except psycopg2.ProgrammingError:
            self.conn.rollback()

        curs.execute("create type %s as (%s);" % (name,
            ", ".join(["%s %s" % p for p in fields])))
        if '.' in name:
            schema, name = name.split('.')
        else:
            schema = 'public'

        curs.execute("""\
            SELECT t.oid
            FROM pg_type t JOIN pg_namespace ns ON typnamespace = ns.oid
            WHERE typname = %s and nspname = %s;
            """, (name, schema))
        oid = curs.fetchone()[0]
        self.conn.commit()
        return oid


def skip_if_json_module(f):
    """Decorate a test so it only runs when extras has *no* json module.

    This is the inverse of ``skip_if_no_json_module``: the test is skipped
    whenever ``extras.json`` is set.
    """
    @wraps(f)
    def _wrapper(self):
        if extras.json is None:
            return f(self)
        return self.skipTest("json module is available")

    return _wrapper

def skip_if_no_json_module(f):
    """Decorate a test so it is skipped when extras has no json module.

    The check reads ``extras.json`` each time the test runs; ``None``
    means no json implementation could be imported.
    """
    @wraps(f)
    def _wrapper(self):
        if extras.json is not None:
            return f(self)
        return self.skipTest("json module not available")

    return _wrapper

def skip_if_no_json_type(f):
    """Decorate a test so it is skipped when the server lacks a json type.

    The wrapper looks up ``json`` in ``pg_type`` through a cursor on
    ``self.conn``; if no row comes back the test is skipped.
    """
    @wraps(f)
    def _wrapper(self):
        cur = self.conn.cursor()
        cur.execute("select oid from pg_type where typname = 'json'")
        row = cur.fetchone()
        if row:
            return f(self)
        return self.skipTest("json not available in test database")

    return _wrapper

class JsonTestCase(ConnectingTestCase):
    """Tests for the ``json`` adapter (``extras.Json``) and the ``json``
    typecasters (``extras.register_json`` / ``register_default_json``)."""

    @skip_if_json_module
    def test_module_not_available(self):
        from psycopg2cffi.extras import Json
        self.assertRaises(ImportError, Json(None).getquoted)

    @skip_if_json_module
    def test_customizable_with_module_not_available(self):
        from psycopg2cffi.extras import Json
        class MyJson(Json):
            def dumps(self, obj):
                assert obj is None
                return "hi"

        # getquoted() returns bytes (compare test_adapt_dumps below); a str
        # literal never equals it on Python 3
        self.assertEqual(MyJson(None).getquoted(), b"'hi'")

    @skip_if_no_json_module
    def test_adapt(self):
        from psycopg2cffi.extras import json, Json

        objs = [None, "te'xt", 123, 123.45,
            _u(b'\xc3\xa0\xe2\x82\xac'), ['a', 100], {'a': 100} ]

        curs = self.conn.cursor()
        # iterate the values themselves: enumerate() would wrap each one in
        # an (index, obj) tuple, so only JSON arrays would ever be checked
        for obj in objs:
            self.assertEqual(curs.mogrify("%s", (Json(obj),)),
                extensions.QuotedString(json.dumps(obj)).getquoted())

    @skip_if_no_json_module
    def test_adapt_dumps(self):
        from psycopg2cffi.extras import json, Json

        class DecimalEncoder(json.JSONEncoder):
            def default(self, obj):
                if isinstance(obj, Decimal):
                    return float(obj)
                return json.JSONEncoder.default(self, obj)

        curs = self.conn.cursor()
        obj = Decimal('123.45')
        dumps = lambda obj: json.dumps(obj, cls=DecimalEncoder)
        self.assertEqual(curs.mogrify("%s", (Json(obj, dumps=dumps),)),
                b"'123.45'")

    @skip_if_no_json_module
    def test_adapt_subclass(self):
        from psycopg2cffi.extras import json, Json

        class DecimalEncoder(json.JSONEncoder):
            def default(self, obj):
                if isinstance(obj, Decimal):
                    return float(obj)
                return json.JSONEncoder.default(self, obj)

        class MyJson(Json):
            def dumps(self, obj):
                return json.dumps(obj, cls=DecimalEncoder)

        curs = self.conn.cursor()
        obj = Decimal('123.45')
        self.assertEqual(curs.mogrify("%s", (MyJson(obj),)),
                b"'123.45'")

    @skip_if_no_json_module
    def test_register_on_dict(self):
        from psycopg2cffi.extras import Json
        extensions.register_adapter(dict, Json)

        try:
            curs = self.conn.cursor()
            obj = {'a': 123}
            self.assertEqual(curs.mogrify("%s", (obj,)),
                    b"""'{"a": 123}'""")
        finally:
            # undo the global adapter registration made above
            del extensions.adapters[dict, extensions.ISQLQuote]


    def test_type_not_available(self):
        curs = self.conn.cursor()
        curs.execute("select oid from pg_type where typname = 'json'")
        if curs.fetchone():
            return self.skipTest("json available in test database")

        self.assertRaises(psycopg2.ProgrammingError,
            extras.register_json, self.conn)

    @skip_if_no_json_module
    @skip_before_postgres(9, 2)
    def test_default_cast(self):
        curs = self.conn.cursor()

        curs.execute("""select '{"a": 100.0, "b": null}'::json""")
        self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None})

        curs.execute("""select array['{"a": 100.0, "b": null}']::json[]""")
        self.assertEqual(curs.fetchone()[0], [{'a': 100.0, 'b': None}])

    @skip_if_no_json_module
    @skip_if_no_json_type
    def test_register_on_connection(self):
        extras.register_json(self.conn)
        curs = self.conn.cursor()
        curs.execute("""select '{"a": 100.0, "b": null}'::json""")
        self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None})

    @skip_if_no_json_module
    @skip_if_no_json_type
    def test_register_on_cursor(self):
        curs = self.conn.cursor()
        extras.register_json(curs)
        curs.execute("""select '{"a": 100.0, "b": null}'::json""")
        self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None})

    @skip_if_no_json_module
    @skip_if_no_json_type
    def test_register_globally(self):
        # 114 / 199 are the built-in OIDs of json / json[]; save any casters
        # already registered for them so they can be restored afterwards
        old = extensions.string_types.get(114)
        olda = extensions.string_types.get(199)
        # register before entering try: if this call fails, the cleanup
        # below would otherwise mask the error with an unbound-name error
        new, newa = extras.register_json(self.conn, globally=True)
        try:
            curs = self.conn.cursor()
            curs.execute("""select '{"a": 100.0, "b": null}'::json""")
            self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None})
        finally:
            extensions.string_types.pop(new.values[0])
            extensions.string_types.pop(newa.values[0])
            if old is not None:
                extensions.register_type(old)
            if olda is not None:
                extensions.register_type(olda)

    @skip_if_no_json_module
    @skip_if_no_json_type
    def test_loads(self):
        json = extras.json
        loads = lambda x: json.loads(x, parse_float=Decimal)
        extras.register_json(self.conn, loads=loads)
        curs = self.conn.cursor()
        curs.execute("""select '{"a": 100.0, "b": null}'::json""")
        data = curs.fetchone()[0]
        self.assertTrue(isinstance(data['a'], Decimal))
        self.assertEqual(data['a'], Decimal('100.0'))

    @skip_if_no_json_module
    @skip_if_no_json_type
    def test_no_conn_curs(self):
        from psycopg2cffi._json import _get_json_oids
        oid, array_oid = _get_json_oids(self.conn)

        # save any casters already registered for json / json[] (114 / 199)
        old = extensions.string_types.get(114)
        olda = extensions.string_types.get(199)
        loads = lambda x: extras.json.loads(x, parse_float=Decimal)
        # register before entering try so a failure here is not masked by
        # the cleanup referencing unbound names
        new, newa = extras.register_json(
            loads=loads, oid=oid, array_oid=array_oid)
        try:
            curs = self.conn.cursor()
            curs.execute("""select '{"a": 100.0, "b": null}'::json""")
            data = curs.fetchone()[0]
            self.assertTrue(isinstance(data['a'], Decimal))
            self.assertEqual(data['a'], Decimal('100.0'))
        finally:
            extensions.string_types.pop(new.values[0])
            extensions.string_types.pop(newa.values[0])
            if old is not None:
                extensions.register_type(old)
            if olda is not None:
                extensions.register_type(olda)

    @skip_if_no_json_module
    @skip_before_postgres(9, 2)
    def test_register_default(self):
        curs = self.conn.cursor()

        loads = lambda x: extras.json.loads(x, parse_float=Decimal)
        extras.register_default_json(curs, loads=loads)

        curs.execute("""select '{"a": 100.0, "b": null}'::json""")
        data = curs.fetchone()[0]
        self.assertTrue(isinstance(data['a'], Decimal))
        self.assertEqual(data['a'], Decimal('100.0'))

        curs.execute("""select array['{"a": 100.0, "b": null}']::json[]""")
        data = curs.fetchone()[0]
        self.assertTrue(isinstance(data[0]['a'], Decimal))
        self.assertEqual(data[0]['a'], Decimal('100.0'))

    @skip_if_no_json_module
    @skip_if_no_json_type
    def test_null(self):
        extras.register_json(self.conn)
        curs = self.conn.cursor()
        curs.execute("""select NULL::json""")
        self.assertEqual(curs.fetchone()[0], None)
        curs.execute("""select NULL::json[]""")
        self.assertEqual(curs.fetchone()[0], None)

    @skip_if_no_json_module
    def test_no_array_oid(self):
        # register the json loader on the text type (OID 25) with no array
        curs = self.conn.cursor()
        t1, t2 = extras.register_json(curs, oid=25)
        self.assertEqual(t1.values[0], 25)
        self.assertEqual(t2, None)

        curs.execute("""select '{"a": 100.0, "b": null}'::text""")
        data = curs.fetchone()[0]
        self.assertEqual(data['a'], 100)
        self.assertEqual(data['b'], None)


def skip_if_no_jsonb_type(f):
    """Return *f* wrapped with ``skip_before_postgres(9, 4)``.

    The jsonb type only exists from PostgreSQL 9.4 onwards.
    """
    requires_pg94 = skip_before_postgres(9, 4)
    return requires_pg94(f)

class JsonbTestCase(ConnectingTestCase):
    """Tests for the ``jsonb`` typecasters.

    All tests are wrapped (after the class body) so they are skipped when no
    json module or no jsonb server type is available.
    """

    @staticmethod
    def myloads(s):
        # loader that tags its result with 'test': 1, so tests can tell the
        # custom loader was used instead of the default one
        import json
        rv = json.loads(s)
        rv['test'] = 1
        return rv

    def test_default_cast(self):
        curs = self.conn.cursor()

        curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""")
        self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None})

        curs.execute("""select array['{"a": 100.0, "b": null}']::jsonb[]""")
        self.assertEqual(curs.fetchone()[0], [{'a': 100.0, 'b': None}])

    def test_register_on_connection(self):
        extras.register_json(self.conn, loads=self.myloads, name='jsonb')
        curs = self.conn.cursor()
        curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""")
        self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1})

    def test_register_on_cursor(self):
        curs = self.conn.cursor()
        extras.register_json(curs, loads=self.myloads, name='jsonb')
        curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""")
        self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1})

    def test_register_globally(self):
        # 3802 / 3807 are the built-in OIDs of jsonb / jsonb[]; save any
        # casters already registered for them so they can be restored
        old = extensions.string_types.get(3802)
        olda = extensions.string_types.get(3807)
        # register before entering try: if this call fails, the cleanup
        # below would otherwise mask the error with an unbound-name error
        new, newa = extras.register_json(self.conn,
            loads=self.myloads, globally=True, name='jsonb')
        try:
            curs = self.conn.cursor()
            curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""")
            self.assertEqual(curs.fetchone()[0], {'a': 100.0, 'b': None, 'test': 1})
        finally:
            extensions.string_types.pop(new.values[0])
            extensions.string_types.pop(newa.values[0])
            if old is not None:
                extensions.register_type(old)
            if olda is not None:
                extensions.register_type(olda)

    def test_loads(self):
        json = extras.json
        loads = lambda x: json.loads(x, parse_float=Decimal)
        extras.register_json(self.conn, loads=loads, name='jsonb')
        curs = self.conn.cursor()
        curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""")
        data = curs.fetchone()[0]
        self.assertTrue(isinstance(data['a'], Decimal))
        self.assertEqual(data['a'], Decimal('100.0'))
        # make sure we are not mangling json too
        curs.execute("""select '{"a": 100.0, "b": null}'::json""")
        data = curs.fetchone()[0]
        self.assertTrue(isinstance(data['a'], float))
        self.assertEqual(data['a'], 100.0)

    def test_register_default(self):
        curs = self.conn.cursor()

        loads = lambda x: extras.json.loads(x, parse_float=Decimal)
        extras.register_default_jsonb(curs, loads=loads)

        curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""")
        data = curs.fetchone()[0]
        self.assertTrue(isinstance(data['a'], Decimal))
        self.assertEqual(data['a'], Decimal('100.0'))

        curs.execute("""select array['{"a": 100.0, "b": null}']::jsonb[]""")
        data = curs.fetchone()[0]
        self.assertTrue(isinstance(data[0]['a'], Decimal))
        self.assertEqual(data[0]['a'], Decimal('100.0'))

    def test_null(self):
        curs = self.conn.cursor()
        curs.execute("""select NULL::jsonb""")
        self.assertEqual(curs.fetchone()[0], None)
        curs.execute("""select NULL::jsonb[]""")
        self.assertEqual(curs.fetchone()[0], None)

decorate_all_tests(JsonbTestCase, skip_if_no_json_module)
decorate_all_tests(JsonbTestCase, skip_if_no_jsonb_type)


class RangeTestCase(unittest.TestCase):
    """Pure-Python tests of ``extras.Range``; no database connection used.

    Uses ``assertTrue``/``assertFalse``/``assertIn``/``assertNotIn`` instead
    of the deprecated ``assert_`` alias, which Python 3.12 removed.
    """

    def test_noparam(self):
        from psycopg2cffi.extras import Range
        r = Range()

        self.assertFalse(r.isempty)
        self.assertEqual(r.lower, None)
        self.assertEqual(r.upper, None)
        self.assertTrue(r.lower_inf)
        self.assertTrue(r.upper_inf)
        self.assertFalse(r.lower_inc)
        self.assertFalse(r.upper_inc)

    def test_empty(self):
        from psycopg2cffi.extras import Range
        r = Range(empty=True)

        self.assertTrue(r.isempty)
        self.assertEqual(r.lower, None)
        self.assertEqual(r.upper, None)
        self.assertFalse(r.lower_inf)
        self.assertFalse(r.upper_inf)
        self.assertFalse(r.lower_inc)
        self.assertFalse(r.upper_inc)

    def test_nobounds(self):
        from psycopg2cffi.extras import Range
        r = Range(10, 20)
        self.assertEqual(r.lower, 10)
        self.assertEqual(r.upper, 20)
        self.assertFalse(r.isempty)
        self.assertFalse(r.lower_inf)
        self.assertFalse(r.upper_inf)
        self.assertTrue(r.lower_inc)
        self.assertFalse(r.upper_inc)

    def test_bounds(self):
        from psycopg2cffi.extras import Range
        for bounds, lower_inc, upper_inc in [
                ('[)', True, False),
                ('(]', False, True),
                ('()', False, False),
                ('[]', True, True),]:
            r = Range(10, 20, bounds)
            self.assertEqual(r.lower, 10)
            self.assertEqual(r.upper, 20)
            self.assertFalse(r.isempty)
            self.assertFalse(r.lower_inf)
            self.assertFalse(r.upper_inf)
            self.assertEqual(r.lower_inc, lower_inc)
            self.assertEqual(r.upper_inc, upper_inc)

    def test_keywords(self):
        from psycopg2cffi.extras import Range
        r = Range(upper=20)
        self.assertEqual(r.lower, None)
        self.assertEqual(r.upper, 20)
        self.assertFalse(r.isempty)
        self.assertTrue(r.lower_inf)
        self.assertFalse(r.upper_inf)
        self.assertFalse(r.lower_inc)
        self.assertFalse(r.upper_inc)

        r = Range(lower=10, bounds='(]')
        self.assertEqual(r.lower, 10)
        self.assertEqual(r.upper, None)
        self.assertFalse(r.isempty)
        self.assertFalse(r.lower_inf)
        self.assertTrue(r.upper_inf)
        self.assertFalse(r.lower_inc)
        self.assertFalse(r.upper_inc)

    def test_bad_bounds(self):
        from psycopg2cffi.extras import Range
        self.assertRaises(ValueError, Range, bounds='(')
        self.assertRaises(ValueError, Range, bounds='[}')

    def test_in(self):
        from psycopg2cffi.extras import Range
        r = Range(empty=True)
        self.assertNotIn(10, r)

        r = Range()
        self.assertIn(10, r)

        r = Range(lower=10, bounds='[)')
        self.assertNotIn(9, r)
        self.assertIn(10, r)
        self.assertIn(11, r)

        r = Range(lower=10, bounds='()')
        self.assertNotIn(9, r)
        self.assertNotIn(10, r)
        self.assertIn(11, r)

        r = Range(upper=20, bounds='()')
        self.assertIn(19, r)
        self.assertNotIn(20, r)
        self.assertNotIn(21, r)

        r = Range(upper=20, bounds='(]')
        self.assertIn(19, r)
        self.assertIn(20, r)
        self.assertNotIn(21, r)

        r = Range(10, 20)
        self.assertNotIn(9, r)
        self.assertIn(10, r)
        self.assertIn(11, r)
        self.assertIn(19, r)
        self.assertNotIn(20, r)
        self.assertNotIn(21, r)

        r = Range(10, 20, '(]')
        self.assertNotIn(9, r)
        self.assertNotIn(10, r)
        self.assertIn(11, r)
        self.assertIn(19, r)
        self.assertIn(20, r)
        self.assertNotIn(21, r)

        # lower > upper: nothing is contained
        r = Range(20, 10)
        self.assertNotIn(9, r)
        self.assertNotIn(10, r)
        self.assertNotIn(11, r)
        self.assertNotIn(19, r)
        self.assertNotIn(20, r)
        self.assertNotIn(21, r)

    def test_nonzero(self):
        from psycopg2cffi.extras import Range
        self.assertTrue(Range())
        self.assertTrue(Range(10, 20))
        self.assertFalse(Range(empty=True))

    def test_eq_hash(self):
        from psycopg2cffi.extras import Range
        def assert_equal(r1, r2):
            # exercise the == operator explicitly, plus hash consistency
            self.assertTrue(r1 == r2)
            self.assertTrue(hash(r1) == hash(r2))

        assert_equal(Range(empty=True), Range(empty=True))
        assert_equal(Range(), Range())
        assert_equal(Range(10, None), Range(10, None))
        assert_equal(Range(10, 20), Range(10, 20))
        assert_equal(Range(10, 20), Range(10, 20, '[)'))
        assert_equal(Range(10, 20, '[]'), Range(10, 20, '[]'))

        def assert_not_equal(r1, r2):
            self.assertTrue(r1 != r2)
            self.assertTrue(hash(r1) != hash(r2))

        assert_not_equal(Range(10, 20), Range(10, 21))
        assert_not_equal(Range(10, 20), Range(11, 20))
        assert_not_equal(Range(10, 20, '[)'), Range(10, 20, '[]'))

    def test_not_ordered(self):
        from psycopg2cffi.extras import Range
        self.assertRaises(TypeError, lambda: Range(empty=True) < Range(0,4))
        self.assertRaises(TypeError, lambda: Range(1,2) > Range(0,4))
        self.assertRaises(TypeError, lambda: Range(1,2) <= Range())
        self.assertRaises(TypeError, lambda: Range(1,2) >= Range())


def skip_if_no_range(f):
    """Decorate test method *f* so it is skipped when the server reports
    ``server_version`` below 90200 (i.e. no range type support).
    """
    @wraps(f)
    def _run_if_ranges_supported(self):
        version = self.conn.server_version
        if version >= 90200:
            return f(self)
        return self.skipTest(
            "server version %s doesn't support range types" % version)

    return _run_if_ranges_supported


class RangeCasterTestCase(ConnectingTestCase):
    """Round-trip tests for PostgreSQL range types through the server.

    Every test is wrapped with ``skip_if_no_range`` (after the class body),
    so they are skipped on servers without range type support.
    """

    # the range types built into PostgreSQL
    builtin_ranges = ('int4range', 'int8range', 'numrange',
        'daterange', 'tsrange', 'tstzrange')

    def test_cast_null(self):
        cur = self.conn.cursor()
        for typname in self.builtin_ranges:
            cur.execute("select NULL::%s" % typname)
            r = cur.fetchone()[0]
            self.assertEqual(r, None)

    def test_cast_empty(self):
        from psycopg2cffi.extras import Range
        cur = self.conn.cursor()
        for typname in self.builtin_ranges:
            cur.execute("select 'empty'::%s" % typname)
            r = cur.fetchone()[0]
            self.assertTrue(isinstance(r, Range), typname)
            self.assertTrue(r.isempty)

    def test_cast_inf(self):
        from psycopg2cffi.extras import Range
        cur = self.conn.cursor()
        for typname in self.builtin_ranges:
            cur.execute("select '(,)'::%s" % typname)
            r = cur.fetchone()[0]
            self.assertTrue(isinstance(r, Range), typname)
            self.assertFalse(r.isempty)
            self.assertTrue(r.lower_inf)
            self.assertTrue(r.upper_inf)

    def test_cast_numbers(self):
        from psycopg2cffi.extras import NumericRange
        cur = self.conn.cursor()
        for typname in ('int4range', 'int8range'):
            # integer ranges come back canonicalized: (10,20) -> [11,20)
            cur.execute("select '(10,20)'::%s" % typname)
            r = cur.fetchone()[0]
            self.assertTrue(isinstance(r, NumericRange))
            self.assertFalse(r.isempty)
            self.assertEqual(r.lower, 11)
            self.assertEqual(r.upper, 20)
            self.assertFalse(r.lower_inf)
            self.assertFalse(r.upper_inf)
            self.assertTrue(r.lower_inc)
            self.assertFalse(r.upper_inc)

        cur.execute("select '(10.2,20.6)'::numrange")
        r = cur.fetchone()[0]
        self.assertTrue(isinstance(r, NumericRange))
        self.assertFalse(r.isempty)
        self.assertEqual(r.lower, Decimal('10.2'))
        self.assertEqual(r.upper, Decimal('20.6'))
        self.assertFalse(r.lower_inf)
        self.assertFalse(r.upper_inf)
        self.assertFalse(r.lower_inc)
        self.assertFalse(r.upper_inc)

    def test_cast_date(self):
        from psycopg2cffi.extras import DateRange
        cur = self.conn.cursor()
        # daterange is canonicalized too: the exclusive lower bound moves
        # up one day and becomes inclusive
        cur.execute("select '(2000-01-01,2012-12-31)'::daterange")
        r = cur.fetchone()[0]
        self.assertTrue(isinstance(r, DateRange))
        self.assertFalse(r.isempty)
        self.assertEqual(r.lower, date(2000,1,2))
        self.assertEqual(r.upper, date(2012,12,31))
        self.assertFalse(r.lower_inf)
        self.assertFalse(r.upper_inf)
        self.assertTrue(r.lower_inc)
        self.assertFalse(r.upper_inc)

    def test_cast_timestamp(self):
        from psycopg2cffi.extras import DateTimeRange
        cur = self.conn.cursor()
        ts1 = datetime(2000,1,1)
        ts2 = datetime(2000,12,31,23,59,59,999)
        cur.execute("select tsrange(%s, %s, '()')", (ts1, ts2))
        r = cur.fetchone()[0]
        self.assertTrue(isinstance(r, DateTimeRange))
        self.assertFalse(r.isempty)
        self.assertEqual(r.lower, ts1)
        self.assertEqual(r.upper, ts2)
        self.assertFalse(r.lower_inf)
        self.assertFalse(r.upper_inf)
        self.assertFalse(r.lower_inc)
        self.assertFalse(r.upper_inc)

    def test_cast_timestamptz(self):
        from psycopg2cffi.extras import DateTimeTZRange
        from psycopg2cffi.tz import FixedOffsetTimezone
        cur = self.conn.cursor()
        ts1 = datetime(2000,1,1, tzinfo=FixedOffsetTimezone(600))
        ts2 = datetime(2000,12,31,23,59,59,999, tzinfo=FixedOffsetTimezone(600))
        cur.execute("select tstzrange(%s, %s, '[]')", (ts1, ts2))
        r = cur.fetchone()[0]
        self.assertTrue(isinstance(r, DateTimeTZRange))
        self.assertFalse(r.isempty)
        self.assertEqual(r.lower, ts1)
        self.assertEqual(r.upper, ts2)
        self.assertFalse(r.lower_inf)
        self.assertFalse(r.upper_inf)
        self.assertTrue(r.lower_inc)
        self.assertTrue(r.upper_inc)

    def _check_adapt_numeric_range(self):
        """Round-trip NumericRange values through int4range, int8range and
        numrange. Shared by test_adapt_number_range and
        test_adapt_numeric_range, which previously duplicated this body."""
        from psycopg2cffi.extras import NumericRange
        cur = self.conn.cursor()

        r = NumericRange(empty=True)
        cur.execute("select %s::int4range", (r,))
        r1 = cur.fetchone()[0]
        self.assertTrue(isinstance(r1, NumericRange), r1)
        self.assertTrue(r1.isempty)

        r = NumericRange(10, 20)
        cur.execute("select %s::int8range", (r,))
        r1 = cur.fetchone()[0]
        self.assertTrue(isinstance(r1, NumericRange))
        self.assertEqual(r1.lower, 10)
        self.assertEqual(r1.upper, 20)
        self.assertTrue(r1.lower_inc)
        self.assertFalse(r1.upper_inc)

        r = NumericRange(Decimal('10.2'), Decimal('20.5'), '(]')
        cur.execute("select %s::numrange", (r,))
        r1 = cur.fetchone()[0]
        self.assertTrue(isinstance(r1, NumericRange))
        self.assertEqual(r1.lower, Decimal('10.2'))
        self.assertEqual(r1.upper, Decimal('20.5'))
        self.assertFalse(r1.lower_inc)
        self.assertTrue(r1.upper_inc)

    def test_adapt_number_range(self):
        self._check_adapt_numeric_range()

    def test_adapt_numeric_range(self):
        self._check_adapt_numeric_range()

    def test_adapt_date_range(self):
        from psycopg2cffi.extras import DateRange, DateTimeRange, DateTimeTZRange
        from psycopg2cffi.tz import FixedOffsetTimezone
        cur = self.conn.cursor()

        d1 = date(2012, 1, 1)
        d2 = date(2012, 12, 31)
        r = DateRange(d1, d2)
        cur.execute("select %s", (r,))
        r1 = cur.fetchone()[0]
        self.assertTrue(isinstance(r1, DateRange))
        self.assertEqual(r1.lower, d1)
        self.assertEqual(r1.upper, d2)
        self.assertTrue(r1.lower_inc)
        self.assertFalse(r1.upper_inc)

        # bounds given as native strings are parsed by the server
        r = DateRange('2012-01-01', '2012-12-31')
        cur.execute("select %s", (r,))
        r1 = cur.fetchone()[0]
        self.assertTrue(isinstance(r1, DateRange))
        self.assertEqual(r1.lower, d1)
        self.assertEqual(r1.upper, d2)
        self.assertTrue(r1.lower_inc)
        self.assertFalse(r1.upper_inc)

        # ... and as unicode strings
        r = DateRange(_u(b'2012-01-01'), _u(b'2012-12-31'))
        cur.execute("select %s", (r,))
        r1 = cur.fetchone()[0]
        self.assertTrue(isinstance(r1, DateRange))
        self.assertEqual(r1.lower, d1)
        self.assertEqual(r1.upper, d2)
        self.assertTrue(r1.lower_inc)
        self.assertFalse(r1.upper_inc)

        r = DateTimeRange(empty=True)
        cur.execute("select %s", (r,))
        r1 = cur.fetchone()[0]
        self.assertTrue(isinstance(r1, DateTimeRange))
        self.assertTrue(r1.isempty)

        ts1 = datetime(2000,1,1, tzinfo=FixedOffsetTimezone(600))
        ts2 = datetime(2000,12,31,23,59,59,999, tzinfo=FixedOffsetTimezone(600))
        r = DateTimeTZRange(ts1, ts2, '(]')
        cur.execute("select %s", (r,))
        r1 = cur.fetchone()[0]
        self.assertTrue(isinstance(r1, DateTimeTZRange))
        self.assertEqual(r1.lower, ts1)
        self.assertEqual(r1.upper, ts2)
        self.assertFalse(r1.lower_inc)
        self.assertTrue(r1.upper_inc)

    def test_register_range_adapter(self):
        from psycopg2cffi.extras import Range, register_range
        cur = self.conn.cursor()
        cur.execute("create type textrange as range (subtype=text)")
        rc = register_range('textrange', 'TextRange', cur)

        TextRange = rc.range
        self.assertTrue(issubclass(TextRange, Range))
        self.assertEqual(TextRange.__name__, 'TextRange')

        r = TextRange('a', 'b', '(]')
        cur.execute("select %s", (r,))
        r1 = cur.fetchone()[0]
        self.assertEqual(r1.lower, 'a')
        self.assertEqual(r1.upper, 'b')
        self.assertFalse(r1.lower_inc)
        self.assertTrue(r1.upper_inc)

        cur.execute("select %s", ([r,r,r],))
        rs = cur.fetchone()[0]
        self.assertEqual(len(rs), 3)
        for r1 in rs:
            self.assertEqual(r1.lower, 'a')
            self.assertEqual(r1.upper, 'b')
            self.assertFalse(r1.lower_inc)
            self.assertTrue(r1.upper_inc)

    def test_range_escaping(self):
        from psycopg2cffi.extras import register_range
        cur = self.conn.cursor()
        cur.execute("create type textrange as range (subtype=text)")
        rc = register_range('textrange', 'TextRange', cur)

        TextRange = rc.range
        cur.execute("""
            create table rangetest (
                id integer primary key,
                range textrange)""")

        # one range per pair of consecutive ASCII characters (plus half-open
        # ones at both ends), cycling through all four bound styles
        bounds = [ '[)', '(]', '()', '[]' ]
        ranges = [ TextRange(low, up, bounds[i % 4])
            for i, (low, up) in enumerate(zip(
                [None] + list(map(chr, range(1, 128))),
                list(map(chr, range(1,128))) + [None],
                ))]
        ranges.append(TextRange())
        ranges.append(TextRange(empty=True))

        errs = 0
        for i, r in enumerate(ranges):
            # not all the ranges make sense:
            # fun fact: select ascii('#') < ascii('$'), '#' < '$'
            # yields... t, f! At least in en_GB.UTF-8 collation,
            # which seems to suggest a supremacy of the pound over the dollar.
            # So some of these ranges will fail to insert. Be prepared but...
            try:
                cur.execute("""
                    savepoint x;
                    insert into rangetest (id, range) values (%s, %s);
                    """, (i, r))
            except psycopg2.DataError:
                errs += 1
                cur.execute("rollback to savepoint x;")

        # ...not too many errors! in the above collate there are 17 errors:
        # assume in other collates we won't find more than 30
        self.assertTrue(errs < 30,
            "too many collate errors. Is the test working?")

        cur.execute("select id, range from rangetest order by id")
        for i, r in cur:
            self.assertEqual(ranges[i].lower, r.lower)
            self.assertEqual(ranges[i].upper, r.upper)
            self.assertEqual(ranges[i].lower_inc, r.lower_inc)
            self.assertEqual(ranges[i].upper_inc, r.upper_inc)
            self.assertEqual(ranges[i].lower_inf, r.lower_inf)
            self.assertEqual(ranges[i].upper_inf, r.upper_inf)

    def test_range_not_found(self):
        from psycopg2cffi.extras import register_range
        cur = self.conn.cursor()
        self.assertRaises(psycopg2.ProgrammingError,
            register_range, 'nosuchrange', 'FailRange', cur)

    def test_schema_range(self):
        cur = self.conn.cursor()
        cur.execute("create schema rs")
        cur.execute("create type r1 as range (subtype=text)")
        cur.execute("create type r2 as range (subtype=text)")
        cur.execute("create type rs.r2 as range (subtype=text)")
        cur.execute("create type rs.r3 as range (subtype=text)")
        cur.execute("savepoint x")

        from psycopg2cffi.extras import register_range
        # all of these must register without error
        register_range('r1', 'r1', cur)
        ra2 = register_range('r2', 'r2', cur)
        rars2 = register_range('rs.r2', 'r2', cur)
        register_range('rs.r3', 'r3', cur)

        # same type name in different schemas -> different OIDs
        self.assertNotEqual(
            ra2.typecaster.values[0],
            rars2.typecaster.values[0])

        # unqualified / wrongly qualified names must not resolve
        self.assertRaises(psycopg2.ProgrammingError,
            register_range, 'r3', 'FailRange', cur)
        cur.execute("rollback to savepoint x;")

        self.assertRaises(psycopg2.ProgrammingError,
            register_range, 'rs.r1', 'FailRange', cur)
        cur.execute("rollback to savepoint x;")

decorate_all_tests(RangeCasterTestCase, skip_if_no_range)


def test_suite():
    """Collect every test case defined in this module into a single suite."""
    loader = unittest.TestLoader()
    return loader.loadTestsFromName(__name__)

if __name__ == "__main__":
    unittest.main()

