/* $Id: agentrex.C,v 1.9 2001/01/29 21:39:06 ericp Exp $ */

/*
 *
 * Copyright (C) 2000 Michael Kaminsky (kaminsky@lcs.mit.edu)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#include "aios.h"
#include "agent.h"
#include "agentconn.h"
#include "sfsconnect.h"
#include "rex_prot.h"
#include "rex.h"
#include <errno.h>



class agentstartfd: public rexfd
{

  str schost;
  cbv succeedcb;

  void
  agentstarted (ref<int> resp, clnt_stat err)
  {
    if(*resp || err)
      warn << "could not start agent on "<< schost << " : " << strerror (err) << "\n";
    else
      warn << "agent forwarding connection started\n";
  }


  
public:

  agentstartfd (cbv succeedcb, str schost, u_int32_t channo, int fd, ptr <aclnt> pr, fdkeeper &fdk):
    rexfd (channo, fd, pr, fdk), schost (schost), succeedcb (succeedcb) {}
    
  virtual void
  newfd (svccb *sbp)
  {

    succeedcb ();

    rexcb_newfd_arg *argp = sbp->template getarg<rexcb_newfd_arg> ();
    
    int s[2];

    if(socketpair(AF_UNIX, SOCK_STREAM, 0, s)) {
      warn << "error creating socketpair for agent forwarding";
      sbp->replyref(false);
      return;
    }

    make_async (s[1]);
    make_async (s[0]);

    sfsagent *a = New sfsagent (s[1]);
    a->setname (schost);
    a->cs = NULL;

    ref <int> resp = new refcounted <int>;
    a->ac->call (AGENT_START, NULL, resp, wrap (this, &agentstartfd::agentstarted, resp)); 

    vNew unixfd (channo, argp->newfd, proxy, fdk, s[0]);

    sbp->replyref (true);
  } 
};


/*
 * rexchannel that runs COMMAND on the remote host and, once the channel
 * exists, hangs an agentstartfd on descriptor 0 so that agent requests
 * from the remote side are forwarded to a local agent named SCHOST.
 */
class agentchannel: public rexchannel
{
  str schost;     // name given to the forwarded agent instances
  cbv succeedcb;  // handed to every agentstartfd created here

public:
  agentchannel (vec<str> command, str schost, cbv succeedcb)
    : rexchannel (1, command),
      schost (schost),
      succeedcb (succeedcb)
  {
  }

  // Invoked when the channel has been established.
  void
  madechannel ()
  {
    vNew agentstartfd (succeedcb, schost, channo, 0, proxy, *this);
  }
};






/*
 * One client-side remote-execution session to the server named by PATH.
 * Created by rex_connect (); drives the handshake
 * connected -> (signauthreq, dologin) -> loggedin -> spawn -> spawned
 * -> attach -> attached, and reports to CB through fail () or succeed ().
 */
class rexsess {
  ptr<sfscon> sessconn;        // connection from sfs_connect_path
  ptr<aclnt> sessclnt;         // rexd_prog_1 client on sessconn's transport
  ptr<aclnt> sfsclnt;          // sfs_program_1 client (used for login)
  ptr<aclnt> rexclnt;          // not assigned in this file (see commented code in attached)
  sfs_sessinfo sessinfo;       // keys computed in attach (), used to encrypt in attached ()
  rex_sesskeydat kscdat;       // inputs hashed into the ksc session key
  rex_sesskeydat kcsdat;       // inputs hashed into the kcs session key
  u_int32_t myauthno;          // authno returned by a successful login
  sfs_seqno seqno;             // session sequence counter

  rexsession *sess;            // created in attached ()

  void fail ();
  void attached (rexd_attach_res *resp, clnt_stat err);
  void attach ();
  void spawned (rexd_spawn_res *resp, clnt_stat err);
  void spawn ();
  void loggedin (sfs_loginres *lresp, clnt_stat err);
  void dologin (ptr<sfsagent_auth_res> ares, clnt_stat err);
  ptr<sfsagent_auth_res> signauthreq (sfsagent_authinit_arg *aa);
  void connected (ptr<sfscon> sc, str err);
  void seq2sessinfo (u_int64_t seqno, sfs_hash *sidp, sfs_sessinfo *sip);
  // Only referenced from commented-out eof-callback registrations.
  void eof () { delete this; }

public:
  str path;                    // server path; key for sesstab
  cb_rex::ptr cb;              // callback of the rex_connect that created us
  ptr<sfsagent_rex_res> cbres; // result filled in by spawned (), reused by succeed ()
  ihash_entry<rexsess> link;   // sesstab linkage

  void succeed (cb_rex::ptr cb);

  rexsess (str path, cb_rex::ptr cb);
  ~rexsess ();
};

// Sessions registered by rexsess::attached, keyed by server path.
// rex_connect reuses an entry found here instead of opening a new session.
ihash<str, rexsess, &rexsess::path, &rexsess::link> sesstab;

/*
 * Derive session information for sequence number SEQNO.
 *
 * Stamps SEQNO into both kcsdat and kscdat (the stamp remains in the
 * members afterwards).  It then sets the kcs/ksc keys of a temporary
 * sfs_sessinfo to the SHA-1 of each record.  If SIDP is non-NULL, it
 * receives the SHA-1 of that sfs_sessinfo.  If SIP is non-NULL, it
 * receives a copy.  The temporary's keys are zeroed before returning.
 */
void
rexsess::seq2sessinfo (u_int64_t seqno, sfs_hash *sidp, sfs_sessinfo *sip)
{
  kscdat.seqno = seqno;
  kcsdat.seqno = seqno;

  sfs_sessinfo info;
  info.type = SFS_SESSINFO;
  info.kcs.setsize (sha1::hashsize);
  info.ksc.setsize (sha1::hashsize);
  sha1_hashxdr (info.kcs.base (), kcsdat, true);
  sha1_hashxdr (info.ksc.base (), kscdat, true);

  if (sidp != NULL)
    sha1_hashxdr (sidp->base (), info, true);
  if (sip != NULL)
    *sip = info;

  // Do not leave derived key material behind in the local copy.
  bzero (info.kcs.base (), info.kcs.size ());
  bzero (info.ksc.base (), info.ksc.size ());
}

/*
 * Begin a session to the rex server named by PATH.  CB is later invoked
 * with the outcome, through fail () or succeed ().  The connection is made
 * asynchronously and the handshake continues in rexsess::connected.
 *
 * All scalar and pointer members are initialized here.  Previously SESS
 * was left indeterminate until attached () ran.
 */
rexsess::rexsess (str path, cb_rex::ptr cb)
  : myauthno (0), seqno (1), sess (NULL), path (path), cb (cb)
{
  sfs_connect_path (path, SFS_REX, wrap (this, &rexsess::connected));
}

/*
 * Zero the key-derivation records and any cached key shares, then drop
 * this session from sesstab.
 */
rexsess::~rexsess ()
{
  bzero (&kcsdat, sizeof (kcsdat));
  bzero (&kscdat, sizeof (kscdat));

  if (cbres) {
    bzero (&cbres->resok->ksc, sizeof (cbres->resok->ksc));
    bzero (&cbres->resok->kcs, sizeof (cbres->resok->kcs));
  }

  // NOTE(review): the session is inserted into sesstab only in attached ();
  // confirm that ihash::remove tolerates an entry that was never inserted.
  sesstab.remove (this);
}

void
rexsess::fail ()
{
  ref<sfsagent_rex_res> r = New refcounted<sfsagent_rex_res> (false);
  cb (r);
}


/*
 * Hand the session's cached result to REPLY.  The session sequence number
 * is advanced first, and the new value is stored in the result.
 */
void
rexsess::succeed (cb_rex::ptr reply)
{
  seqno++;
  cbres->resok->seqno = seqno;
  reply (cbres);
}

/*
 * Reply to REXD_ATTACH.
 *
 * On success, this replaces the session transport with an axprt_crypt
 * keyed from sessinfo, registers the session in sesstab, and creates the
 * rexsession.  When agent forwarding is enabled it also opens the
 * forwarding channel.  succeed (cb) then runs either immediately or from
 * that channel's callback.
 *
 * Errors are reported through fail (), like the other handshake steps,
 * instead of terminating the whole process with fatal.  RESP is freed on
 * every path.
 */
void
rexsess::attached (rexd_attach_res *resp, clnt_stat err)
{
  if (err) {
    warn << "FAILED (" << err << ")\n";
    delete resp;
    fail ();
    return;
  }
  if (*resp != SFS_OK) {
    // XXX
    warn << "FAILED (attach err " << int (*resp) << ")\n";
    delete resp;
    fail ();
    return;
  }
  delete resp;
  warnx << "attached\n";

  // Reuse the underlying descriptor for an encrypting transport.
  sessconn->x = axprt_crypt::alloc (sessconn->x->reclaim ());
  sessconn->x->encrypt (sessinfo.kcs.base (), sessinfo.kcs.size (),
			sessinfo.ksc.base (), sessinfo.ksc.size ());

  //rexclnt = aclnt::alloc (sessconn->x, rex_prog_1);
  //  rexclnt->seteofcb (wrap (this, &rexsess::eof));
  sesstab.insert (this);

  sess = New rexsession (path, sessconn->x);

  //todo: make this a command line option (for rex)
  bool forwagent = true;
  if (forwagent) {
    vec<str> suidcommand;
    suidcommand.setsize (2);
    suidcommand[0] = "suidconnect";
    suidcommand[1] = "agent";

    sess->makechannel (New refcounted <agentchannel> (suidcommand, path, wrap (this, &rexsess::succeed, cb)));
  }
  else
    succeed (cb);
}

void
rexsess::attach ()
{
  rexd_attach_arg arg;

  arg.seqno = seqno++;
  seq2sessinfo (0, &arg.sessid, NULL);
  seq2sessinfo (arg.seqno, &arg.newsessid, &sessinfo);

  rexd_attach_res *resp = New rexd_attach_res;
  sessclnt->call (REXD_ATTACH, &arg, resp, 
		  wrap (this, &rexsess::attached, resp));
}

/*
 * Reply to REXD_SPAWN.
 *
 * On success, this records the server's key shares in kcsdat/kscdat and
 * builds cbres from the client and server shares.  It then proceeds to
 * attach ().  On failure it reports through fail ().  RESP is freed on
 * every path; previously it leaked on both error paths.
 */
void
rexsess::spawned (rexd_spawn_res *resp, clnt_stat err)
{
  if (err) {
    warn << "FAILED (" << err << ")\n";
    delete resp;
    fail ();
    return;
  }
  if (resp->err != SFS_OK) {
    // XXX
    warn << "FAILED (spawn err " << int (resp->err) << ")\n";
    delete resp;
    fail ();
    return;
  }
  warnx << "spawned\n";

  kcsdat.sshare = resp->resok->kmsg.kcs_share;
  kscdat.sshare = resp->resok->kmsg.ksc_share;
  delete resp;

  cbres = New refcounted<sfsagent_rex_res> (true);
  cbres->resok->kcs.kcs_share = kcsdat.cshare;
  cbres->resok->kcs.ksc_share = kcsdat.sshare;
  cbres->resok->ksc.kcs_share = kscdat.cshare;
  cbres->resok->ksc.ksc_share = kscdat.sshare;

  attach ();
}

void
rexsess::spawn ()
{
  rexd_spawn_arg arg;
  rnd.getbytes (arg.kmsg.kcs_share.base (),
		arg.kmsg.kcs_share.size ());
  rnd.getbytes (arg.kmsg.ksc_share.base (),
		arg.kmsg.ksc_share.size ());
  kcsdat.type = SFS_KCS;
  kcsdat.cshare = arg.kmsg.kcs_share;
  kscdat.type = SFS_KSC;
  kscdat.cshare = arg.kmsg.ksc_share;
  arg.command.setsize (1);
  arg.command[0] = "proxy";

  rexd_spawn_res *resp = New rexd_spawn_res;
  sessclnt->call (REXD_SPAWN, &arg, resp, wrap (this, &rexsess::spawned, resp),
		  authuint_create (myauthno));
}

/*
 * Reply to SFSPROC_LOGIN.
 *
 * On SFSLOGIN_OK, this stores the returned authno and proceeds to
 * spawn ().  Any RPC error or other status is reported through fail ().
 * LRESP is freed on every path; previously it leaked on the RPC-error and
 * bad-status paths.
 */
void
rexsess::loggedin (sfs_loginres *lresp, clnt_stat err)
{
  if (err) {
    warn << "loggedin: error\n";
    delete lresp;
    fail ();
    return;
  }
  if (!lresp) {
    warn << "loggedin: lresp is NULL error\n";
    fail ();
    return;
  }

  switch (lresp->status) {
  case SFSLOGIN_OK:
    myauthno = *lresp->authno;
    break;
#if 0
  case SFSLOGIN_MORE:
    {
      sfscd_agentreq_arg arg;
      arg.aid = aid;
      arg.agentreq.set_type (AGENTCB_AUTHMORE);
      arg.agentreq.more->authinfo = sp->authinfo;
      arg.agentreq.more->seqno = seqno;
      arg.agentreq.more->challenge = *sres.resmore;
      cbase = cdc->call (SFSCDCBPROC_AGENTREQ, &arg, &ares,
			 wrap (this, &userauth::aresult));
      break;
    }
  case SFSLOGIN_BAD:
    ntries++;
    sendreq ();
    break;
  case SFSLOGIN_ALLBAD:
    finish (0);
    break;
#endif
  default:
    warn << "userauth: bad status in loginres!\n";
    delete lresp;
    fail ();
    return;
  }

  delete lresp;
  spawn ();
}

/*
 * Send SFSPROC_LOGIN with the certificate from ARES, using seqno 1; the
 * reply goes to loggedin ().  Fails if ERR is set or ARES does not carry
 * a certificate (authenticate is false).
 */
void
rexsess::dologin (ptr<sfsagent_auth_res> ares, clnt_stat err)
{
  if (err) {
    warn << "dologin: " << err << "\n";
    fail ();
    return;
  }
  if (!ares->authenticate) {
    warn << "dologin: no certificate\n";
    fail ();
    return;
  }

  sfs_loginarg login;
  login.seqno = 1;
  login.certificate = *ares->certificate;

  sfs_loginres *res = New sfs_loginres;
  sfsclnt->call (SFSPROC_LOGIN, &login, res,
		 wrap (this, &rexsess::loggedin, res));
}

/*
 * Build a signed SFS authentication request for AA and return it in an
 * sfsagent_auth_res whose certificate holds the XDR-encoded sfs_autharg.
 *
 * Returns NULL if keynum (AA->ntries) yields no key or AA->authinfo is not
 * of type SFS_AUTHINFO.  If hashing, marshalling or serialization fails,
 * the returned result has authenticate set to false.
 */
ptr<sfsagent_auth_res>
rexsess::signauthreq (sfsagent_authinit_arg *aa)
{
  // NOTE(review): keynum () is defined elsewhere; presumably it returns the
  // ntries-th key held by the agent, or NULL when there is none.
  key *k = keynum (aa->ntries);
  if (!k || aa->authinfo.type != SFS_AUTHINFO) {
    warn ("signauthreq: couldn't find key\n");
    return NULL;
  }

  ref<sfsagent_auth_res> res = New refcounted<sfsagent_auth_res> (true);
  sfs_autharg ar (SFS_AUTHREQ);
  sfs_signed_authreq sar;
  str rawsar;

  // The signed payload carries the caller's seqno; usrinfo is zeroed.
  ar.req->usrkey = k->k->n;
  sar.type = SFS_SIGNED_AUTHREQ;
  sar.seqno = aa->seqno;
  bzero (sar.usrinfo.base (), sar.usrinfo.size ());

  // Steps run in order and stop at the first failure:
  //   1. hash the authinfo into sar.authid,
  //   2. marshal sar into rawsar,
  //   3. (comma operator) sign rawsar into the request, then serialize the
  //      whole autharg into the certificate.
  // Only xdr2bytes' result decides step 3; sign_r's result is not tested.
  if (!sha1_hashxdr (sar.authid.base (), aa->authinfo)
      || !(rawsar = xdr2str (sar))
      || !(ar.req->signed_req = k->k->sign_r (rawsar),
	   xdr2bytes (*res->certificate, ar))) {
    warn ("signauthreq: xdr failed\n");
    res->set_authenticate (false);
  }

  // Log requestor, server name, armored hostid and service, on both
  // the success and failure paths.
  warn << aa->requestor << ": " << aa->authinfo.name << ":"
       << armor32 (str (aa->authinfo.hostid.base (),
			aa->authinfo.hostid.size ()))
       << " (" << implicit_cast<int> (aa->authinfo.service) << ")\n";

  return res;
}

/*
 * Completion of sfs_connect_path.  On success, this keeps the connection
 * and opens rexd and SFS RPC clients over its transport.  It then signs an
 * authentication request for this server and logs in.  Failures are
 * reported through fail ().
 */
void
rexsess::connected (ptr<sfscon> sc, str err)
{
  if (!sc) {
    warn << path << ": FAILED (" << err << ")\n";
    fail ();
    return;
  }

  sessconn = sc;
  sessclnt = aclnt::alloc (sc->x, rexd_prog_1);
  sfsclnt = aclnt::alloc (sc->x, sfs_program_1);

  //  sessclnt->seteofcb (wrap (this, &rexsess::eof));

  // Describe the server being authenticated to.
  sfsagent_authinit_arg req;
  req.authinfo.type = SFS_AUTHINFO;
  req.authinfo.service = SFS_REX;
  req.authinfo.name = sc->servinfo.host.hostname;
  req.authinfo.hostid = sc->hostid;
  req.authinfo.sessid = sc->sessid;
  req.ntries = 0;
  req.requestor = "";
  req.seqno = 1;

  ptr<sfsagent_auth_res> ares = signauthreq (&req);
  if (ares)
    dologin (ares, RPC_SUCCESS);
  else
    fail ();
}

/*
 * Entry point: answer CB with a session to PATH.  A session already in
 * sesstab is reused via succeed (); otherwise a new rexsess is started.
 */
void
rex_connect (str path, cb_rex::ptr cb)
{
  rexsess *existing = sesstab[path];
  if (!existing) {
    vNew rexsess (path, cb);
    return;
  }
  warn << "rexsess: hash lookup succeeded\n";
  existing->succeed (cb);
}
