""" 
   Copyright (C) 2001 PimenTech SARL (http://www.pimentech.net)

   This library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Library General Public License as
   published by the Free Software Foundation; either version 2 of the
   License, or (at your option) any later version.

   This library is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   Library General Public License for more details.

   You should have received a copy of the GNU Library General Public
   License along with this library; see the file COPYING.LIB.  If not,
   write to the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
   Boston, MA 02111-1307, USA.  
"""

from common import *
from pgmlgraph import *

# SQL type given to the seq_<table> variables declared in the generated
# insert function, and to 'id_' attributes that are not described in the pgml.
SERIALTYPE = 'integer'
# Not referenced in the visible part of this file -- presumably read by
# modules importing this one; verify before removing.
KEYROW = 'key_row'

class AsTable(CoreTable):
	"""
	A table as it is used inside a view.

	Constructor arguments:
	  as_name    -- name handed to CoreTable.__init__ (the pgtml 'as' value)
	  table_name -- name of the underlying table

	Attributes set by the constructor:
	  table_name  -- the underlying table name
	  type        -- always the string 'AsTable'
	  constraints -- Map: constraint name -> Map of that constraint's properties
	  sequence    -- name of the sequence read by the generated insert function,
	                 built as "<table>_id_<table prefix>_seq"
	"""
	def __init__(self, as_name, table_name):
		CoreTable.__init__(self, as_name)
		self.table_name = table_name
		self.type = 'AsTable'
		self.constraints = Map('Constraints')
		# The second occurrence of the table name is cut so that both
		# occurrences together use at most 23 characters (presumably the
		# 31-character identifier limit minus "_id_" and "_seq" -- TODO confirm).
		# The bound is clamped at 0: a negative slice bound would count from
		# the end of the string and keep most of the name instead of nothing.
		room = max(0, 23 - len(table_name))
		self.sequence = "%s_id_%s_seq" % (table_name, table_name[:room])

	def insert_constraint(self, constraint, property_name, property_value):
		"""
		Stores property_value under property_name for the given constraint,
		creating the constraint's property Map on first use.
		"""
		if not self.constraints.has_key(constraint):
			self.constraints[constraint] = Map('Properties')
		self.constraints[constraint][property_name] = property_value

class PgtmlGraph(Graph):
	"""
	TODO:
	- complete the pgtml attributes with the ones from the pgml
	"""
	def __init__(self, pgmlbase_simplified):
		"""
		Creates an empty view graph.

		pgmlbase_simplified -- kept as self.base; the other methods index it
		                       by table name to reach the pgml description of
		                       a table.
		"""
		# The base Graph constructor is called with None as its argument.
		Graph.__init__(self, None)
		self.base = pgmlbase_simplified

	def _compare_attribute(self, a, b): 
		va = int(a['rank'])
		vb = int(b['rank'])
		if va == vb: return 0
		if va < vb: return -1
		return 1


	def _get_constraints(self, vtables, type):
		"""
		Collects the conditions of a WHERE clause for the given view tables.

		vtables -- the view tables to inspect
		type    -- key looked up in every pgtml constraint ('select' and
		           'update' are the values passed within this class); the
		           constraint is skipped when that entry equals 'no'

		Returns the list of condition strings: first the join conditions
		derived from the 'ref_' entries of a table, then its pgtml
		constraints.  When a 'ref_' entry holds more than one target, a
		message is printed and None is returned.
		"""
		conditions = []
		for view_table in vtables:
			# join conditions
			for edge in view_table.keys():
				if edge[:4] != 'ref_':
					continue
				targets = view_table[edge]
				if len(targets) > 1:
					print("%s : Too much transitions " % targets)
					return
				target = targets.values()[0]
				conditions.append("%s.%s=%s.id_%s" % \
								  (view_table.name, edge, target, target.object.table_name))
			# pgtml constraints
			conditions.extend(["%s.%s%s" % (view_table.name, rule['name'], rule['value'])
							   for rule in view_table.object.constraints.values()
							   if rule[type] != 'no'])
		return conditions
	
	def get_select(self):
		"""
		Builds the SELECT query of the view.

		Every attribute carrying a 'rank' whose 'select' entry is not 'no'
		is selected, in rank order, with an AS alias when its 'as' entry is
		set.  The WHERE clause (and the closing ';') is only appended when
		there is at least one constraint.

		Returns the tuple (query, constraints).
		"""
		where_parts = self._get_constraints(self.values(), 'select') # WHERE clause
		selected = []
		from_parts = []
		for view_table in self.values():
			table = view_table.object
			# SELECT clause
			selected.extend([attr for attr in table.values()
							 if attr.has_key('rank') and attr['select'] != 'no'])
			# FROM clause
			from_parts.append("%s %s" % (table.table_name, view_table.name))

		selected.sort(self._compare_attribute)
		columns = []
		for attr in selected:
			column = "%s.%s" % (attr['table_name'], attr['name'])
			if attr['as']:
				column = "%s AS %s" % (column, attr['as'])
			columns.append(column)

		query = "SELECT %s\nFROM %s" % \
				(joinfields(columns, ","), joinfields(from_parts, ","))
		if len(where_parts) != 0:
			query = "%s WHERE %s;" % (query, joinfields(where_parts, " AND "))
		return (query, where_parts)

	def write_select_query(self, output):
		"Writes the raw SQL select query built by get_select() to output"
		select_result = self.get_select()
		output.write(select_result[0])

	def get_insert(self):
		"""
		Builds the PL/pgSQL function "insert_<view name>" inserting one row
		in every table of the view.

		The function takes one argument per attribute carrying a 'rank'
		whose 'insert' entry is not 'no', in rank order.  One seq_<table>
		variable is declared per table: it receives nextval() of the table's
		sequence, or the value of the parent's variable when the pgml marks
		the table as an ISA table.

		Returns the tuple (query, func_name, func_args, func_labelargs).
		Returns None, after printing a message, when a 'ref_' entry has more
		than one target.  Writes an error and calls exit(0) when the view has
		no name.
		"""
		if self.name == '':
			stderr.write('ERROR : no view name in pgtml !')
			exit(0)

		attributes = []
		s_seq = []
		func_name = "insert_" + self.name

		for vtable in self.values():
			for attr in vtable.object.values():
				if attr.has_key('rank') and attr['insert'] != 'no':
					attributes.append(attr)
			s_seq.append("seq_%s" % vtable.name)

		attributes.sort(self._compare_attribute)
		# Ranks in argument order.  Taken from the sorted attributes so that
		# the position of a rank is always the position of its argument;
		# sorting the raw rank values separately orders strings
		# lexicographically ('10' before '2') and shifts the $n placeholders.
		indexes = [attr['rank'] for attr in attributes]
		func_args = [attr['type'] for attr in attributes]
		func_labelargs = [attr['name'] for attr in attributes]

		query = """
		DROP FUNCTION %s(%s);
		CREATE FUNCTION %s(%s)
		RETURNS integer AS '
		DECLARE
			-- %s
		""" % (func_name, joinfields(func_args, ','), \
			   func_name, joinfields(func_args, ','), \
			   joinfields(func_labelargs, " "))

		for seq in s_seq:
			query = "%s\t%s %s;\n" % (query, seq, SERIALTYPE)

		query = "%sBEGIN\n" % query

		# Assignments copying a parent's id are collected apart and appended
		# after every nextval() assignment.
		fake_seq = ''
		for vtable in self.values():
			isatablename = self.base[vtable.object.table_name].object.isa
			if isatablename:
				# ISA tables share ids
				fake_seq = fake_seq + "\tseq_%s=seq_%s;\n" % \
						   (vtable.name, vtable['ref_%s' % isatablename].values()[0].name)
			else:
				query = "%s\tseq_%s=nextval(''%s''::text);\n" % \
						(query, vtable.name, vtable.object.sequence)

		query = query + fake_seq

		for vtable in self.values():
			query = "%s\tINSERT INTO %s (" % (query, vtable.object.table_name)
			s_attr = [ "id_%s" % vtable.object.table_name ]
			s_val = [ "seq_%s" % vtable.name ]
			# function arguments
			for attr in vtable.object.values():
				if attr.has_key('rank') and attr['insert'] != 'no':
					s_attr.append(attr['name'])
					s_val.append("$%s" % (indexes.index(attr['rank'])+1))
			# reference sequences
			for ref in vtable.keys():
				if ref[:4] == 'ref_':
					if len(vtable[ref])>1:
						print("%s : Too much transitions " % vtable[ref])
						return
					tableref = vtable[ref].values()[0]
					s_attr.append(ref)
					s_val.append("seq_%s" % tableref)

			query = "%s %s) values (%s); " % \
					(query, joinfields(s_attr, ","), joinfields(s_val, ","))

		# PL/pgSQL statements end with ';' -- RETURN included.
		query = """%s
			RETURN 1;
		END;
		' LANGUAGE 'plpgsql';\n""" % query

		return (query, func_name, func_args, func_labelargs)
						 
	def write_insert_function(self, output):
		"Writes the raw SQL insert function built by get_insert() to output"
		insert_result = self.get_insert()
		output.write(insert_result[0])

	def get_update(self):
		"""
		Builds the PL/pgSQL function "update_<view name>".

		Shape of the generated code:

		DROP FUNCTION update_viewname(<argument types>);
		CREATE FUNCTION update_viewname(<argument types>) RETURNS integer AS '
		DECLARE
			toid RECORD;
			BEGIN
			FOR toid IN
				SELECT <alias>.oid as <alias>_oid, ...
				FROM <table> <alias>, ...
				[WHERE <join conditions and constraints joined by AND>]
			LOOP
				UPDATE <table> SET <column>=$n, ... WHERE oid=toid.<alias>_oid;
			END LOOP;
			RETURN 1;
			END;
		' LANGUAGE 'plpgsql';

		The arguments are the attributes and constraints carrying a 'rank'
		whose 'update' entry is not 'no', in rank order.

		Returns the tuple (query, func_name, func_args, func_labelargs).
		Writes an error and calls exit(0) when the view has no name.
		"""
		if self.name == '':
			stderr.write('ERROR : no view name in pgtml !')
			exit(0)

		attributes = []
		s_fields = []
		s_tables = []
		func_name = "update_" + self.name

		s_constraints = self._get_constraints(self.values(), 'update') # WHERE clause

		for vtable in self.values():
			for attr in vtable.object.values() + vtable.object.constraints.values():
				if attr.has_key('rank') and attr['update'] != 'no':
					attributes.append(attr)
			# SELECT
			s_fields.append("%s.oid as %s_oid" % (vtable.name, vtable.name))
			# FROM
			s_tables.append("%s %s" % (vtable.object.table_name, vtable.name))

		attributes.sort(self._compare_attribute)
		# Ranks in argument order.  Taken from the sorted attributes so that
		# the position of a rank is always the position of its argument;
		# sorting the raw rank values separately orders strings
		# lexicographically ('10' before '2') and shifts the $n placeholders.
		indexes = [attr['rank'] for attr in attributes]
		func_args = [attr['type'] for attr in attributes]
		func_labelargs = [attr['name'] for attr in attributes]

		# As in get_select(), the WHERE keyword is only emitted when there is
		# something to put after it; a dangling WHERE is not valid SQL.
		if len(s_constraints):
			s_where = "\n\t\t\t\tWHERE %s" % joinfields(s_constraints, " AND ")
		else:
			s_where = ''

		query = """
		DROP FUNCTION %s(%s);
		CREATE FUNCTION %s(%s)
		RETURNS integer AS '
		DECLARE
			-- %s
			toid RECORD;
			BEGIN
			FOR toid IN
				SELECT %s
				FROM %s%s
			LOOP
		""" % (func_name, joinfields(func_args, ','), \
			   func_name, joinfields(func_args, ','), \
			   joinfields(func_labelargs, " "),
			   joinfields(s_fields, ","),
			   joinfields(s_tables, ","),
			   s_where)

		for vtable in self.values():
			s_var = []
			for attr in vtable.object.values():
				if attr.has_key('rank') and attr['update'] != 'no':
					s_var.append("%s=$%s" % (attr['name'],(indexes.index(attr['rank'])+1)))
			# a table without any updatable attribute gets no UPDATE statement
			if len(s_var):
				query = "%s\tUPDATE %s SET %s WHERE oid=toid.%s_oid;" % \
						(query, vtable.object.table_name, joinfields(s_var, ","), vtable.name)
		query = """%s
		END LOOP;
		RETURN 1;
		END;
		' LANGUAGE 'plpgsql';\n""" % query

		return (query, func_name, func_args, func_labelargs)

	def write_update_function(self, output):
		"Writes the raw SQL update function built by get_update() to output"
		update_result = self.get_update()
		output.write(update_result[0])

	class PgtmlHandler(CommonHandler):
		def __init__(self, name, graph):
			"""
			name  -- handed to CommonHandler.__init__
			graph -- the graph filled while parsing; its 'base' attribute
			         is kept as self.base
			"""
			CommonHandler.__init__(self, name)
			# nothing is being parsed yet
			self.currentRelation = None
			self.currentTable = None
			# where the parsed tables and joints go
			self.graph = graph
			self.base = graph.base
			
		def startElement(self, name, attrs):
			"""
			Handles the opening tag of a pgtml element.

			name  -- tag name
			attrs -- attributes of the tag; they are copied in a Map first

			Tags acted upon:
			  view       -- stores the view name in the graph
			  table      -- makes the table current, inserting a new AsTable
			                in the graph when the lookup by 'name' finds nothing
			  joint      -- adds the edges between the 'from' and 'to' tables
			                for every matching 'ref_' key of their pgml tables
			  attribute,
			  constraint -- only inside a 'table' tag: copies the pgml
			                properties, then the pgtml ones, into the current
			                table

			usage() is called when a joint names a table that is not in the
			graph, or when an attribute is neither in the pgml nor an 'id_'
			attribute.
			"""
			CommonHandler.startElement(self, name, attrs)

			valueOf = Map("ValueOf")
			for attr in attrs:
				valueOf[attr]=attrs[attr]

			if name == 'view':
				self.graph.name = valueOf['name']

			if name == 'table':
				if self.graph[valueOf['name']]:
					self.currentTable = self.graph[valueOf['name']].object
				else:
					# Note: the 'as' parameter is used as the name of the table
					self.currentTable = AsTable(valueOf['as'], valueOf['name'])
					self.graph.insert(self.currentTable)

			elif name == 'joint':
				from_vertex = self.graph[valueOf['from']]
				to_vertex = self.graph[valueOf['to']]
				# Checked before '.object' is read, so that an unknown table
				# is reported instead of failing on the attribute access.
				if from_vertex is None or to_vertex is None:
					usage('joint ERROR : table %s or %s is not declared in pgtml' \
						  % (valueOf['from'], valueOf['to']))
				t1 = from_vertex.object
				t2 = to_vertex.object
				if t1 is None or t2 is None:
					usage('joint ERROR : table %s or %s is not declared in pgtml' \
						  % (valueOf['from'], valueOf['to']))
				t1_base = self.base[t1.table_name].object
				# Result unused; the lookup is kept because it fails when the
				# 'through' relation is unknown -- NOTE(review): confirm that
				# this is the intent.
				through_rel = self.base[valueOf['through']].object
				t2_base = self.base[t2.table_name].object

				for r in ("ref_%s" % t2.table_name, "ref_%s_%s" % (valueOf['through'], t2.table_name)):
					if t1_base.has_key(r):
						self.graph.insert_edge(t1, r, t2)
				for r in ("ref_%s" % t1.table_name, "ref_%s_%s" % (valueOf['through'], t1.table_name)):
					if t2_base.has_key(r):
						self.graph.insert_edge(t2, r, t1)

			elif name == 'attribute' or name == 'constraint':
				if self.parentTag == 'table':
					attr_name = valueOf['name']
					pgml_attr = self.base[self.currentTable.table_name].object[attr_name]

					# Attributes and constraints are stored the same way;
					# only the destination differs.
					if name == 'attribute':
						insert_property = self.currentTable.insert_attribute
					else:
						insert_property = self.currentTable.insert_constraint
						if not valueOf['value'] and valueOf['rank']:
							# Caution: if there are gaps in the ranks, the
							# placeholder numbers will be shifted
							valueOf['value'] = '=$%s' % (int(valueOf['rank'])+1)

					# pgml properties first, so that the pgtml ones override them
					for properties in (pgml_attr, valueOf):
						if properties:
							for key,value in properties.items():
								insert_property(attr_name, key, value)
						else: # attr not found in pgml
							if attr_name[:3] == 'id_': # ids are not defined in the pgml graph
								insert_property(attr_name, 'type', SERIALTYPE)
							else:
								# FATAL ERROR
								usage('%s ERROR : %s attribute, table %s, is not present in pgml' \
									  % (name, attr_name, self.currentTable.table_name))
					insert_property(attr_name, 'table_name', self.currentTable.name)

			
		def endElement(self, name):
			"""
			Handles a closing tag: leaving a 'table' element means that no
			table is current any more.
			"""
			CommonHandler.endElement(self, name)
			if name != 'table':
				return
			self.currentTable = None


	
	def read(self, filename):
		"""
		Parses the pgtml file 'filename' with the xmlproc driver.

		A separate PgtmlHandler bound to this graph is installed for each
		of the document, DTD, error and entity-resolver roles.
		"""
		parser = saxexts.ParserFactory().make_parser('xml.sax.drivers.drv_xmlproc')
		roles = (('setDocumentHandler', 'doc_handler'),
				 ('setDTDHandler', 'dtd_handler'),
				 ('setErrorHandler', 'err_handler'),
				 ('setEntityResolver', 'ent_handler'))
		for setter_name, label in roles:
			getattr(parser, setter_name)(self.PgtmlHandler(label, self))
		parser.parse(filename)
