#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>

#include "comm.h"
#include "model.h"
#include "packets.h"
#include "global.h"

#define EXTRA_VERSION "-1"
#define COMM_VERSION "$Revision: 1.15 $"
#define CDEBUG(...) DEBUG(DCOMM,"Communication",__VA_ARGS__)


/*
 * Return the protocol version string exchanged during the client/server
 * handshake: MODEL_VERSION, COMM_VERSION and EXTRA_VERSION concatenated.
 *
 * The string is built on the first call and cached in a function-local
 * static, so every call returns the same pointer.  Callers must neither
 * free nor modify it.  The lazy initialisation is unsynchronised.
 */
char *create_version (void) {
  static char *comm_version = NULL;

  if (comm_version == NULL) {
    /* Measure each piece once instead of re-scanning the literals. */
    size_t model_len = strlen(MODEL_VERSION);
    size_t comm_len  = strlen(COMM_VERSION);
    size_t extra_len = strlen(EXTRA_VERSION);
    char *buf = malloc(model_len + comm_len + extra_len + 1);

    /* Same out-of-memory policy as the other allocators in this file;
     * previously a failed malloc was written through unchecked. */
    assert(buf != NULL);
    memcpy(buf, MODEL_VERSION, model_len);
    memcpy(buf + model_len, COMM_VERSION, comm_len);
    /* extra_len + 1 also copies the terminating NUL. */
    memcpy(buf + model_len + comm_len, EXTRA_VERSION, extra_len + 1);
    comm_version = buf;
  }
  return comm_version;
}

static int
send_packet(void *data, void *user_data) {
  CDEBUG("Sending packet of type %d",
           packet_get_type((Network_packet*)user_data));
  network_connection_sent((Network_connection *)data,
                          (Network_packet*)user_data);
  return TRUE;
}

/*
 * Per-packet callback used by handle_serv_connection() for clients in the
 * state being processed.  It receives the server, the client and one packet
 * whose type already matched the caller's mask.  It returns non-zero on
 * success; zero makes handle_serv_connection() drop the client.  The
 * handler does not own p: handle_serv_connection() frees it afterwards.
 */
typedef int (*cs_packet_handler) (Comm_server *serv,
                                  Comm_server_client *client,
                                  Network_packet *p);


/*
 * Run one processing pass over every client of serv that is in `state`.
 *
 * For each such client, queued packets are popped and passed to func for as
 * long as the client stays in `state`.  A client is dropped (removed from
 * serv and freed) in three cases:
 *   - its connection is no longer NW_OK;
 *   - it sends a packet whose type is not in the `packets` mask;
 *   - func returns zero.
 *
 * If error_is_fatal is non-zero, the first dropped client ends the pass and
 * -1 is returned.  Otherwise failed clients are removed after the pass, and
 * the remaining clients are still processed.
 *
 * Returns the number of surviving clients still in `state` after the pass,
 * so 0 means nobody is waiting in that state.  Returns -1 as described
 * above.
 */
int
handle_serv_connection(Comm_server *serv,cs_packet_handler func,
                       int state,int packets,int error_is_fatal) {
  int ret = 0;
  List_ptr *ptr;
  List *failed;
  Comm_server_client *c;
  Network_packet *p;

  ptr = new_list_ptr(serv->clients);
  if (!list_ptr_first(ptr)) {
    del_list_ptr(ptr);
    return 0;
  }
  /* Dropped clients are only freed after the walk: removing and freeing
   * the current client under ptr would invalidate the iterator, and the
   * record would still be read afterwards (c->state below). */
  failed = new_list();
  do {
    int ok = TRUE;

    c = list_ptr_get_data(ptr);
    if (c->state != state) {
      continue;
    }
    if (network_connection_state(c->con) != NW_OK) {
      WARN("Network connection to %s failed in state %d",
           network_connection_get_sremote(c->con),
           c->state);
      ok = FALSE;
    }
    /* A handler may move the client to another state; stop feeding it
     * packets as soon as it does. */
    while (ok && c->state == state
           && (p = network_connection_pop(c->con)) != NULL) {
      if (!(packet_get_type(p) & packets)) {
        WARN("Wrong packet for state %d from %s",
              state,
              network_connection_get_sremote(c->con));
        ok = FALSE;
      } else if (!(func)(serv,c,p)) {
        CDEBUG("Packet handler for state %d failed on %s",
               state,
               network_connection_get_sremote(c->con));
        ok = FALSE;
      }
      del_network_packet(p);
    }
    if (!ok) {
      if (error_is_fatal) {
        del_list_ptr(ptr);
        del_list(failed); /* always empty in fatal mode */
        comm_server_del_client(serv,c);
        del_comm_server_client(c);
        return -1;
      }
      list_append(failed,c);
      continue;
    }
    if (c->state == state) ret++;
  } while (list_ptr_next(ptr));
  del_list_ptr(ptr);

  while ((c = list_pop(failed)) != NULL) {
    comm_server_del_client(serv,c);
    del_comm_server_client(c);
  }
  del_list(failed);
  return ret;
}

/*
 * Wrap a freshly accepted connection in a server-side client record.
 * The record owns con (del_comm_server_client() frees it) and starts in
 * COMM_STATE_VERSION, with no hostname and no remotes queue yet.
 */
Comm_server_client *
new_comm_server_client(Network_connection *con) {
  Comm_server_client *client = malloc(sizeof *client);

  assert(client != NULL);
  client->remotes = NULL;
  client->hostname = NULL;
  client->state = COMM_STATE_VERSION;
  client->con = con;
  return client;
}

/*
 * Free a client record together with its connection (if still attached),
 * its remotes list and its hostname.  The remotes entries are other
 * clients of the server; del_list() is presumably freeing only the list
 * itself, not those entries.
 */
void
del_comm_server_client(Comm_server_client *client) {
  if (client->con)
    del_network_connection(client->con);
  if (client->remotes)
    del_list(client->remotes);
  free(client->hostname);
  free(client);
}

/*
 * Create a server listening on `port` (AF_UNSPEC) for a game of nrplayers
 * players.  Returns NULL if the listener cannot be created.  The player
 * table starts empty; comm_server_register_player() fills it.
 */
Comm_server *new_comm_server(int port,int nrplayers) {
  Comm_server *result;
  Network_listener *listener;

  listener = new_network_listener(port,AF_UNSPEC);
  if (listener == NULL)
    return NULL;

  result = malloc(sizeof(Comm_server));
  assert(result != NULL);
  /* calloc: the size is overflow-checked and unused slots read as NULL.
   * Request at least one slot so calloc(0, ...) returning NULL cannot trip
   * the assert for a zero-player server. */
  result->players = calloc(nrplayers > 0 ? (size_t)nrplayers : 1,
                           sizeof(char *));
  assert(result->players != NULL);
  result->registered_players = 0;
  result->nrplayers = nrplayers;
  result->conns = new_list();
  result->clients = new_list();
  result->listeners = new_list();
  list_append(result->listeners,listener);

  return result;
}

/*
 * Destroy a server.  This frees:
 *   - the registered player names and the player table;
 *   - every remaining client, with its connection;
 *   - every listener;
 *   - the lists themselves.
 * Name pointers obtained from comm_server_get_name() become invalid.
 */
void 
del_comm_server(Comm_server *serv) {
  Comm_server_client *client;
  Network_listener *l;
  int i;

  /* The names were strdup()ed by comm_server_register_player() and were
   * previously leaked; only the first registered_players slots are set. */
  for (i = 0; i < serv->registered_players; i++) {
    free(serv->players[i]);
  }
  free(serv->players);
  while ((client = list_pop(serv->clients)) != NULL) {
    del_comm_server_client(client);
  }
  while ((l = list_pop(serv->listeners)) != NULL) {
    del_network_listener(l);
  }
  del_list(serv->conns);
  del_list(serv->clients);
  del_list(serv->listeners);
  free(serv);
}

/*
 * Start tracking client: the record goes into serv->clients and its
 * connection into serv->conns, the list handed to network_update().
 */
void 
comm_server_add_client(Comm_server *serv, Comm_server_client *client) {
  Network_connection *con = client->con;

  list_append(serv->clients, client);
  list_append(serv->conns, con);
}

/*
 * Stop tracking client: remove it and its connection from serv's lists.
 * The record itself is not freed; callers follow up with
 * del_comm_server_client().
 */
void
comm_server_del_client(Comm_server *serv,Comm_server_client *client) {
  Network_connection *con = client->con;

  list_del(serv->clients, client);
  list_del(serv->conns, con);
}

/*
 * Store a private copy of name in the next free player slot.
 * Returns the new player's number, or -1 when all nrplayers slots are
 * already taken.
 */
int comm_server_register_player(Comm_server *serv, char *name) {
  int slot = serv->registered_players;

  if (slot == serv->nrplayers) {
    INFO("Couldn't register %s, server full",name);
    return -1;
  }
  serv->players[slot] = strdup(name);
  serv->registered_players = slot + 1;
  INFO("registered player %s (%d)",name, slot);
  return slot;
}

/*
 * list_foreach() callback over the server's listeners: adopt every
 * connection pending on the listener as a new client in VERSION state.
 * Always returns TRUE.
 */
static int
cs_grab_connection(void *data, void *user_data) {
  Network_listener *listener = data;
  Comm_server *serv = user_data;

  for (;;) {
    Network_connection *incoming = network_listener_pop(listener);

    if (incoming == NULL)
      break;
    INFO("Network connection from %s",network_connection_get_sremote(incoming));
    comm_server_add_client(serv,new_comm_server_client(incoming));
  }
  return TRUE;
}

/*
 * VERSION-state handler: the client's version string must equal
 * create_version() exactly.  On a match the client moves to PLAYREQ and
 * gets a positive reply.  On a mismatch a warning is logged and FALSE
 * makes the caller drop the client.
 */
int cs_handle_version( Comm_server *serv, Comm_server_client *client,
                       Network_packet *p) {
  Network_packet *reply;

  if (strcmp(packet_client_info_get_version(p),create_version()) != 0) {
    WARN("Client sent version %s, but we got %s",
        packet_client_info_get_version(p),create_version());
    return FALSE;
  }
  client->state = COMM_STATE_PLAYREQ;
  reply = new_packet_reply(TRUE);
  network_connection_sent(client->con,reply);
  del_network_packet(reply);
  return TRUE;
}

/*
 * PLAYREQ-state handler:
 *   - a PACKET_PLAYER_REQ registers a player and is answered with the
 *     resulting number (-1 if the server is full);
 *   - a PACKET_SYNC ends registration for this client, which moves on to
 *     CCONNREQ.
 * Other types cannot arrive (the caller's mask filters them).
 * Always returns TRUE.
 */
int cs_handle_preq( Comm_server *serv, Comm_server_client *client,
                       Network_packet *p) {
  int type = packet_get_type(p);

  if (type == PACKET_PLAYER_REQ) {
    int id = comm_server_register_player(serv,packet_player_req_get_name(p));
    Network_packet *answer = new_packet_reply(id);

    network_connection_sent(client->con,answer);
    del_network_packet(answer);
  } else if (type == PACKET_SYNC) {
    client->state = COMM_STATE_CCONNREQ;
  }
  return TRUE;
}

/*
 * Accept connections and run the version / player-registration handshake.
 * Continues until all nrplayers slots are registered and no client is
 * left in the VERSION or PLAYREQ state.  The listeners are then closed, so
 * no further clients can join.
 *
 * Returns FALSE if a client fails while in PLAYREQ (fatal for this phase).
 * A client that fails the version check is merely dropped.
 */
int 
comm_server_wait_for_others(Comm_server *serv) {
  Network_listener *listener;
  int handshaking;

  while (serv->registered_players != serv->nrplayers) {
    do {
      int in_version, in_playreq;

      network_update(serv->listeners,serv->conns,0);
      list_foreach(serv->listeners,cs_grab_connection,serv);

      in_version = handle_serv_connection(serv,cs_handle_version,
                                          COMM_STATE_VERSION,
                                          PACKET_CLIENT_INFO,FALSE);
      in_playreq = handle_serv_connection(serv,cs_handle_preq,
                                          COMM_STATE_PLAYREQ,
                                          PACKET_PLAYER_REQ|PACKET_SYNC,TRUE);
      if (in_playreq < 0)
        return FALSE;
      handshaking = in_version + in_playreq;
    } while (handshaking > 0);
  }

  /* Registration is complete: stop accepting new connections. */
  while ((listener = list_pop(serv->listeners)) != NULL)
    del_network_listener(listener);
  CDEBUG("Player registration done");
  return TRUE;
}

/*
 * list_foreach() callback: append the client in data to the List in
 * user_data if that client has not been given a remotes queue yet.
 * Always returns TRUE.
 */
int
cs_do_add_connreq(void *data, void *user_data) {
  Comm_server_client *candidate = data;

  if (candidate->remotes == NULL)
    list_append((List *)user_data, candidate);
  return TRUE;
}

/*
 * Ask client to connect to the first peer in its remotes queue and move it
 * to SCONNREQ.  When the queue is empty the client has reached all its
 * peers and moves to PLAYINFO instead.
 */
void cs_send_connreq(Comm_server_client *client) {
  Comm_server_client *next = list_pop(client->remotes);
  Network_packet *req;

  if (next == NULL) {
    client->state = COMM_STATE_PLAYINFO;
    return;
  }
  req = new_packet_conn_req(next->hostname,next->port);
  network_connection_sent(client->con,req);
  del_network_packet(req);
  /* Peek, not pop: the peer is removed only once the client confirms the
   * connection (cs_handle_sconnreq). */
  list_prepend(client->remotes,next);
  client->state = COMM_STATE_SCONNREQ;
}

/*
 * list_foreach() callback over serv->clients: give the client in data a
 * queue of peers to connect to, then send the first request.
 *
 * The queue holds every client whose remotes is still NULL, which excludes
 * this client itself and those visited earlier in the walk.  Assuming
 * list_foreach visits in list order, each pair of clients is thus
 * connected exactly once.  Always returns TRUE.
 */
int
cs_do_tell_remotes(void *data, void *user_data) {
  Comm_server_client *me = data;
  Comm_server *server = user_data;

  /* Must happen before the walk so this client does not queue itself. */
  me->remotes = new_list();
  list_foreach(server->clients,cs_do_add_connreq,me->remotes);
  cs_send_connreq(me);
  return TRUE;
}

/*
 * CCONNREQ-state handler: record the host/port on which this client
 * accepts peer connections and move it to WCONNREQ.  Always returns TRUE.
 */
int cs_handle_cconnreq( Comm_server *serv, Comm_server_client *client,
                       Network_packet *p) {
  char *host = strdup(packet_conn_req_gethost(p));

  client->hostname = host;
  client->port = packet_conn_req_getport(p);
  client->state = COMM_STATE_WCONNREQ;
  CDEBUG("Can connect on %s port %d",client->hostname,client->port);
  return TRUE;
}
/*
 * SCONNREQ-state handler: the client reports whether its connection to the
 * requested peer succeeded.  On success the peer is dequeued and the next
 * request is sent.  On failure FALSE is returned so the caller drops the
 * client.
 */
int cs_handle_sconnreq( Comm_server *serv, Comm_server_client *client,
                       Network_packet *p) {
  CDEBUG("cs_handle_sconnreq");
  if (!packet_reply_get_reply(p)) {
    /* FIXME report nicer error */
    WARN("Connection between two clients failed");
    return FALSE;
  }
  list_pop(client->remotes);
  cs_send_connreq(client);
  return TRUE;
}

int comm_server_send_connections(Comm_server *serv) {
  int ret;
  Network_packet *p;
  /* wait for all clients to tell us which port/hostname they want the
   * connection on */
  while ((ret = 
            handle_serv_connection(serv,cs_handle_cconnreq,COMM_STATE_CCONNREQ,
                                   PACKET_CONN_REQ,TRUE)) > 0)  {
     network_update(serv->listeners,serv->conns,0);
  }
  if (ret < 0) {
    return FALSE;
  }
  /* All clients are in WCONNREQ state */ 
  CDEBUG("Telling clients to connect together");
  list_foreach(serv->clients,cs_do_tell_remotes,serv);
  while ((ret = 
            handle_serv_connection(serv,cs_handle_sconnreq,COMM_STATE_SCONNREQ,
                                   PACKET_REPLY,TRUE)) > 0) {
    network_update(serv->listeners,serv->conns,0);
  }
  if (ret < 0) {
    return FALSE;
  }
  p = new_packet_sync();
  list_foreach(serv->conns,send_packet,p);
  del_network_packet(p);
  CDEBUG("Connections between clients succeeded");
  return TRUE;
}

/*
 * Broadcast every player's number and name to all clients, followed by a
 * SYNC that marks the end of the list (comm_client_get_player_info reads
 * up to it).  Always returns TRUE.
 */
int 
comm_server_send_player_info(Comm_server *serv) {
  Network_packet *msg;
  int id;

  for (id = 0; id < serv->nrplayers; ++id) {
    msg = new_packet_player_info(id,serv->players[id]);
    list_foreach(serv->conns,send_packet,msg);
    del_network_packet(msg);
  }

  msg = new_packet_sync();
  list_foreach(serv->conns,send_packet,msg);
  del_network_packet(msg);
  network_update(NULL,serv->conns,0);
  return TRUE;
}

/*
 * Stream the level text to every client in chunks, then send a SYNC.
 * new_packet_level_info() reports through its second argument how much of
 * the remaining text it packed (presumably in chars).  The packet carrying
 * a zero count is still sent and ends the stream.  Always returns TRUE.
 */
int comm_server_send_level(Comm_server *serv,char *level) {
  Network_packet *packet;
  int offset = 0;
  int chunk = 0;

  for (;;) {
    packet = new_packet_level_info(level + offset, &chunk);
    offset += chunk;
    list_foreach(serv->conns,send_packet,packet);
    del_network_packet(packet);
    if (chunk <= 0)
      break;
  }

  packet = new_packet_sync();
  list_foreach(serv->conns,send_packet,packet);
  del_network_packet(packet);
  network_update(NULL,serv->conns,0);
  return TRUE;
}

/*
 * Name of registered player number `player`.  Returns NULL when no such
 * player has been registered, matching comm_client_get_name().  Previously
 * an out-of-range number read uninitialised or out-of-bounds slots.
 * The string remains owned by serv.
 */
char *comm_server_get_name(Comm_server *serv,int player) {
  if (player < 0 || player >= serv->registered_players) {
    return NULL;
  }
  return serv->players[player];
}

/*
 * Turn the server into a game.  Every client connection is moved into a
 * new Comm_game, then the client records and the server are destroyed.
 * serv must not be used afterwards.
 */
Comm_game *comm_server_finalize(Comm_server *serv) {
  Comm_game *game = new_comm_game(serv->nrplayers);
  Comm_server_client *client;

  for (client = list_pop(serv->clients); client != NULL;
       client = list_pop(serv->clients)) {
    /* Detach the connection so del_comm_server_client() leaves it alone. */
    list_append(game->connections,client->con);
    client->con = NULL;
    del_comm_server_client(client);
  }
  del_comm_server(serv);
  return game;
}

/*
 * Wait for the next packet from the server, pumping the network (client
 * listeners and connections) while none is queued.
 *
 * Returns the packet, which the caller must free, if its type is in the
 * packettype mask.  Returns NULL if the server connection leaves NW_OK, or
 * if a packet of another type arrives (it is logged and discarded).
 */
static Network_packet *
cc_get_server_packet(Comm_client *client,int packettype) {
  Network_packet *p = network_connection_pop(client->server);

  while (p == NULL) {
    if (network_connection_state(client->server) != NW_OK)
      return NULL;
    network_update(client->listeners,client->connections,0);
    p = network_connection_pop(client->server);
  }
  if ((packet_get_type(p) & packettype) == 0) {
    WARN("Got packet of type %X, when expecting %X",
         packet_get_type(p),packettype);
    del_network_packet(p);
    return NULL;
  }
  return p;
}

/*
 * Connect to the server at host:port and perform the version handshake.
 * Returns the new client, or NULL in any of these cases:
 *   - the connection cannot be opened;
 *   - memory runs out;
 *   - the server rejects our version or the handshake fails.
 */
Comm_client *new_comm_client(char *host,int port) {
  Comm_client *result;
  Network_connection *c;
  Network_packet *p;

  c = new_network_connection(host,port);
  if (c == NULL) {
    WARN("Network connection to server failed");
    /* NULL rather than FALSE: this function returns a pointer. */
    return NULL;
  }
  result = malloc(sizeof(Comm_client));
  if (result == NULL) {
    /* Previously unchecked; don't leak the connection just opened. */
    del_network_connection(c);
    return NULL;
  }
  result->listener = NULL;
  result->listeners = new_list();
  result->players = new_list();
  result->server = c;
  result->connections = new_list();
  result->level = NULL;
  result->levellen = 0;
  /* The server connection also lives in connections, so it is polled and
   * freed together with the peer connections. */
  list_append(result->connections,c);

  p = new_packet_client_info(create_version());
  network_connection_sent(c,p);
  del_network_packet(p);

  p = cc_get_server_packet(result,PACKET_REPLY);
  if (p == NULL || !packet_reply_get_reply(p)) {
    WARN("Version mismatch between the server and this client");
    if (p!=NULL) del_network_packet(p);
    del_comm_client(result);
    return NULL;
  }
  del_network_packet(p);
  return result;
}

/*
 * Free a client and everything it owns:
 *   - its listeners;
 *   - all connections, including the server connection (also held in
 *     connections);
 *   - the player records with their names;
 *   - the received level text.
 */
void 
del_comm_client(Comm_client *client) {
  Network_listener *listener;
  Network_connection *con;
  Comm_client_player *player;

  for (listener = list_pop(client->listeners); listener != NULL;
       listener = list_pop(client->listeners))
    del_network_listener(listener);
  for (con = list_pop(client->connections); con != NULL;
       con = list_pop(client->connections))
    del_network_connection(con);
  for (player = list_pop(client->players); player != NULL;
       player = list_pop(client->players)) {
    free(player->name);
    free(player);
  }
  del_list(client->listeners);
  del_list(client->connections);
  del_list(client->players);
  free(client->level);
  free(client);
}

/*
 * Ask the server to register a player called name.  Returns the value of
 * the server's reply (the player number, or -1 if the server is full).
 * Returns -2 if no valid reply arrived.
 */
int comm_client_register_player(Comm_client *client, char *name) {
  Network_packet *request;
  Network_packet *answer;
  int id;

  CDEBUG("Sending player request");
  request = new_packet_player_req(name);
  network_connection_sent(client->server,request);
  del_network_packet(request);

  answer = cc_get_server_packet(client,PACKET_REPLY);
  if (answer == NULL)
    return -2;
  id = packet_reply_get_reply(answer);
  del_network_packet(answer);
  return id;
}

/*
 * Tell the server this client has no more players to register by sending
 * a SYNC, then pump the connections.  Always returns TRUE.
 */
int comm_client_register_done(Comm_client *client) {
  Network_packet *sync = new_packet_sync();

  network_connection_sent(client->server,sync);
  del_network_packet(sync);
  network_update(NULL,client->connections,0);
  return TRUE;
}

/*
 * Client side of the peer-connection phase.
 *
 * Steps:
 *   1. Open a listener on `port`, or the next port above it that works.
 *   2. Tell the server where peers can reach us: `host`, or the local
 *      address of the server connection when host is NULL.
 *   3. Connect to each peer the server names in PACKET_CONN_REQ, answering
 *      each with a PACKET_REPLY.
 *   4. On PACKET_SYNC, adopt the peers that connected to our listener and
 *      close the listener.
 *
 * Returns TRUE on success.  Returns FALSE if no listener can be opened, or
 * if the server connection fails or sends an unexpected packet.
 */
int comm_client_wait_for_connections(Comm_client *client,char *host,int port) {
  Network_listener *listener;
  Network_connection *con;
  Network_packet *p,*r;

  /* Probe upward from the requested port.  The old loop condition
   * (listener == NULL && port++) could exit with a NULL listener (port 0)
   * and never stopped at the end of the port range. */
  for (;;) {
    CDEBUG("Trying to listen on port %d",port);
    listener = new_network_listener(port,AF_UNSPEC);
    if (listener != NULL)
      break;
    if (port <= 0 || port >= 65535) {
      WARN("Could not open a listener for client connections");
      return FALSE;
    }
    port++;
  }
  list_append(client->listeners,listener);

  if (host == NULL) {
    host = network_connection_get_local(client->server);
  }

  r = new_packet_conn_req(host,port);
  network_connection_sent(client->server,r);
  del_network_packet(r);

  for (;;) {
    p = cc_get_server_packet(client,PACKET_CONN_REQ|PACKET_SYNC);
    if (p == NULL) return FALSE;
    if (packet_get_type(p) == PACKET_SYNC) {
      del_network_packet(p);
      break;
    } else {
      CDEBUG("Request to connect to %s at %d",
              packet_conn_req_gethost(p),
              packet_conn_req_getport(p)
            );
      con = new_network_connection(packet_conn_req_gethost(p),
                                   packet_conn_req_getport(p)
                                   );
      r = new_packet_reply(con != NULL);
      network_connection_sent(client->server,r);
      if (con!=NULL) list_append(client->connections,con);
      del_network_packet(p);
      del_network_packet(r);
    }
  }
  while ((con = network_listener_pop(listener)) != NULL) {
    list_append(client->connections,con);
  }
  /* Remove exactly the listener opened here, not whichever is first. */
  list_del(client->listeners,listener);
  del_network_listener(listener);
  return TRUE;
}

/*
 * Receive the player list the server broadcasts: PACKET_PLAYER_INFO
 * packets terminated by a PACKET_SYNC.  Each entry is appended to
 * client->players.
 *
 * Returns TRUE once the SYNC arrives.  Returns FALSE if the server
 * connection fails, an unexpected packet arrives, or memory runs out.
 */
int comm_client_get_player_info(Comm_client *client) {
  Network_packet *p;
  Comm_client_player *player;

  for (;;) {
    p = cc_get_server_packet(client,PACKET_SYNC|PACKET_PLAYER_INFO);
    if (p == NULL) return FALSE;
    if (packet_get_type(p) == PACKET_SYNC) {
      del_network_packet(p);
      return TRUE;
    }
    player = malloc(sizeof(Comm_client_player));
    if (player == NULL) {
      /* Previously dereferenced unchecked. */
      del_network_packet(p);
      return FALSE;
    }
    player->num = packet_player_info_get_number(p);
    player->name = strdup(packet_player_info_get_name(p));
    list_append(client->players,player);
    del_network_packet(p);
  }
}

/* Number of players received from the server so far. */
int 
comm_client_get_nrplayers(Comm_client *client) {
  int count = list_length(client->players);
  return count;
}

static int
search_client_player(void *data, void *user_data) {
  return ((Comm_client_player *)data)->num == *((int32_t *)user_data);
}

/*
 * Name of the player whose number is `player`, or NULL if the server did
 * not announce such a player.  The string stays owned by client.
 */
char *comm_client_get_name(Comm_client *client,int player) {
   Comm_client_player *found = list_search(client->players,search_client_player,&player);

   if (found == NULL)
     return NULL;
   return found->name;
}

/*
 * Receive the level text: PACKET_LEVEL_INFO chunks terminated by a
 * PACKET_SYNC.  The chunks are concatenated into client->level, which
 * stays NUL-terminated at all times, and client->levellen is set to its
 * length.  Any previously received level is discarded.
 *
 * Returns TRUE once the SYNC arrives.  Returns FALSE if the server
 * connection fails, an unexpected packet arrives, or memory runs out; the
 * text received so far is kept, terminated.
 */
int comm_client_get_level(Comm_client *client) {
  Network_packet *p;
  const char *chunk;
  size_t chunklen;
  char *buf;

  /* Start from an empty, terminated buffer.  The old code overwrote
   * client->level (leaking it) without resetting levellen. */
  free(client->level);
  client->levellen = 0;
  client->level = malloc(1);
  if (client->level == NULL) {
    return FALSE;
  }
  client->level[0] = '\0';

  for (;;) {
    p = cc_get_server_packet(client,PACKET_SYNC|PACKET_LEVEL_INFO);
    if (p == NULL) return FALSE;
    if (packet_get_type(p) == PACKET_SYNC) {
      del_network_packet(p);
      return TRUE;
    }
    chunk = packet_level_info_get_data(p);
    chunklen = strlen(chunk);
    /* Grow via a temporary so the old buffer survives a failed realloc. */
    buf = realloc(client->level, client->levellen + chunklen + 1);
    if (buf == NULL) {
      del_network_packet(p);
      return FALSE;
    }
    /* chunklen + 1 also copies the chunk's terminating NUL. */
    memcpy(buf + client->levellen, chunk, chunklen + 1);
    client->level = buf;
    client->levellen += chunklen;
    del_network_packet(p);
  }
}

/*
 * Level text collected by comm_client_get_level(); NULL before that call.
 * The buffer is owned by client.
 */
char *comm_client_get_levelstr(Comm_client *client) {
  char *level = client->level;
  return level;
}

/*
 * Turn a fully connected client into a game.  Every connection (server
 * and peers) moves into a new Comm_game sized by the announced player
 * count, then the client is destroyed.
 */
Comm_game *comm_client_finalize(Comm_client *client) {
  Comm_game *game = new_comm_game(list_length(client->players));
  Network_connection *con;

  for (con = list_pop(client->connections); con != NULL;
       con = list_pop(client->connections))
    list_append(game->connections,con);
  del_comm_client(client);
  return game;
}

/*
 * Create an empty game for nrplayers players: empty connection lists and
 * one zeroed event word per player.  Aborts (assert) if memory runs out.
 */
Comm_game * new_comm_game(int nrplayers) {
  Comm_game *result;
  /* At least one slot so calloc(0, ...) returning NULL cannot trip the
   * assert below. */
  size_t slots = nrplayers > 0 ? (size_t)nrplayers : 1;

  CDEBUG("Creating a game for %d players",nrplayers);
  result = calloc(1, sizeof(Comm_game));
  assert(result != NULL);

  result->connections = new_list();
  result->synced_connections = new_list();

  /* calloc zero-fills and overflow-checks; the old malloc was unchecked. */
  result->events = calloc(slots, sizeof(uint32_t));
  assert(result->events != NULL);
  return result;
}

/*
 * Destroy a game and every connection it owns.  After a failed
 * comm_game_update(), connections can be split between connections and
 * synced_connections.  Both lists are therefore drained; previously the
 * synced ones leaked.
 */
void del_comm_game(Comm_game *g) {
  Network_connection *c;

  while ((c = list_pop(g->connections)) != NULL) {
    del_network_connection(c);
  }
  while ((c = list_pop(g->synced_connections)) != NULL) {
    del_network_connection(c);
  }
  del_list(g->connections);
  del_list(g->synced_connections);

  free(g->events);
  free(g);
}

/*
 * Record player's new event word and, if it changed, broadcast it to every
 * peer as a PACKET_ACTION and pump the connections.  Unchanged events are
 * not re-sent.  player is not bounds-checked.
 */
void 
comm_game_send_event(Comm_game *g,int player,uint32_t event) {
  Network_packet *action;

  if (event == g->events[player])
    return;
  g->events[player] = event;
  if (g->connections == NULL)
    return;
  action = new_packet_action(player,event);
  list_foreach(g->connections,send_packet,action);
  del_network_packet(action);
  network_update(NULL,g->connections,1);
}

/* Last known event word for player (local or received); no bounds check. */
uint32_t comm_game_get_event(Comm_game *g,int player) {
  return *(g->events + player);
}

/*
 * list_foreach() callback that comm_game_update() runs over g->connections.
 *
 * Drains every packet queued on connection c:
 *   - PACKET_ACTION stores the sender's event word in g->events[player];
 *   - PACKET_SYNC means this peer has finished the round.  c moves from
 *     g->connections to g->synced_connections and draining stops.
 *   - any other type is a protocol error.
 *
 * Returns FALSE if c is no longer NW_OK or an unexpected packet arrives,
 * and TRUE otherwise.
 *
 * NOTE(review): list_del() removes c from the very list list_foreach() is
 * walking.  This is only safe if list_foreach tolerates removal of the
 * current element; confirm against the list implementation.
 * NOTE(review): the player index comes off the network and indexes
 * g->events without a bounds check, so a malformed packet can write out of
 * bounds.  Comm_game does not visibly record nrplayers, so a check would
 * need that added. TODO.
 */
static int
handle_game_events(void *data, void *user_data) {
  Comm_game *g = (Comm_game *)user_data;
  Network_connection *c = (Network_connection *)data;
  Network_packet *p;

  if (network_connection_state(c) != NW_OK) {
    WARN("Network connection lost");
    return FALSE;
  }
  while((p = network_connection_pop(c)) != NULL) {
    CDEBUG("Got packet of type %d",packet_get_type(p));
    switch(packet_get_type(p)) {
      case PACKET_ACTION:
        /* Unvalidated network-supplied index, see NOTE above. */
        g->events[packet_action_get_player(p)] = packet_action_get_event(p);
        del_network_packet(p);
        break;
      case PACKET_SYNC:
        /* Peer done for this round: park it until comm_game_update swaps. */
        list_del(g->connections,c);
        list_append(g->synced_connections,c);
        del_network_packet(p);
        return TRUE;
      default:
        CDEBUG("Wrong packet for this state %d",packet_get_type(p));
        del_network_packet(p);
        return FALSE;
      }
  }
  return TRUE;
}

/*
 * Finish one round of the game:
 *   1. send a SYNC to every peer;
 *   2. poll until each peer's SYNC has been received, or a peer fails.
 *      handle_game_events() moves synced peers to synced_connections.
 *   3. swap the two lists, so g->connections again holds the synced peers
 *      for the next round.
 * Returns TRUE, or the failing list_foreach() result.
 */
int 
comm_game_update(Comm_game *g) {
  List *done;
  Network_packet *sync = new_packet_sync();
  int status = TRUE;

  list_foreach(g->connections,send_packet,sync);
  del_network_packet(sync);

  while (status == TRUE && list_length(g->connections) != 0) {
    network_update(NULL,g->connections,0);
    status = list_foreach(g->connections,handle_game_events,g);
  }

  done = g->synced_connections;
  g->synced_connections = g->connections;
  g->connections = done;
  return status;
}
