/*******************************************************************************
*
* McXtrace, x-ray tracing package
*         Copyright, All rights reserved
*         DTU Physics, Kgs. Lyngby, Denmark
*         Synchrotron SOLEIL, Saint-Aubin, France
*
* Component: MCPL_output
*
* %Identification
* Written by: Erik B Knudsen 
* Date: Aug 2016
* Origin: DTU Physics
*
* Detector-like component that writes photon state parameters into an mcpl-format
* binary, virtual-source photon file.
*
* %D
* Detector-like component that writes photon state parameters into an mcpl-format 
* binary, virtual-source photon file.
*
* MCPL is short for Monte Carlo Particle List, and is a new format for sharing events
* between e.g. MCNP(X), Geant4, and McXtrace.
*
* When used with MPI, the component will output #MPI nodes individual MCPL files that
* can be merged using the mcpltool.
*
* MCPL_output allows a few flags to tweak the output files:
* 1. If polarisationuse is unset (default) the polarisation vector will not be stored (saving space)
* 2. If doubleprec is unset (default) data will be stored as 32 bit floating points, effectively cutting the output file size in half.
* 3. Extra information may be attached to each ray in the form of a userflag, a user-defined variable which is packed into 32 bits. If
* the user variable does not fit in 32 bits the value will be truncated and likely garbage. If more than one variable is to be attached to
* each photon, they must be packed together into the 32 bits.
*
* These features are set this way to keep file sizes as manageable as possible.
*
* Example: MCPL_output( filename="voutput", verbose=1, userflag="flag", userflagcomment="Photon Id" )
*
* %P
* INPUT PARAMETERS
*
* filename: [str]         Name of photon file to write. If not given, the component name will be used.
* verbose: [1]            If 1) Print summary information for created MCPL file. 2) Also print summary of first 10 particles information stored in the MCPL file. >2) Also print information for first 10 particles as they are being stored by McXtrace
* weight_mode: [1]        weight_mode=1: record initial ray count in MCPL stat:sum entry and rescale particle weights. weight_mode=2: classical (deprecated) mode of outputting particle weights directly.
* polarisationuse: [1]    Enable storing the polarisation state of the photon.
* doubleprec: [1]         Use double precision storage
* userflag: [1]           Extra variable to attach to each photon. The value of this variable will be packed into a 32 bit integer.
* userflagcomment: [str]  String variable to describe the userflag. If this string is empty (the default) no userflags will be stored.
* buffermax: [1]          Maximal number of events to save ( <= MAXINT), GPU/OpenACC only
* %E
*******************************************************************************/

DEFINE COMPONENT MCPL_output

SETTING PARAMETERS ( string filename=0, int weight_mode=-1, int verbose=0,
                     int polarisationuse=0, int doubleprec=0,
                     string userflag="",  string userflagcomment="",
                     buffermax=0 )

DEPENDENCY "@MCPLFLAGS@"

SHARE
%{
  #include <mcpl.h>
  #include <sys/stat.h>
  /* Return 1 when stat() succeeds on `filename`, 0 otherwise. */
  int
  mcpl_file_exist (char* filename) {
    struct stat info;
    int rc = stat (filename, &info);
    if (rc != 0)
      return 0;
    return 1;
  }
%}

DECLARE
%{
  // Handle of the MCPL file being written (created in INITIALIZE, closed in FINALLY).
  mcpl_outfile_t outputfile;
  // Single-particle scratch record filled per ray in TRACE (set to NULL in OPENACC mode).
  mcpl_particle_t* particle;
  // 1 when a non-empty userflagcomment was given, i.e. userflags are stored; else 0.
  int userflagenabled;
  char* outfilename; // "/some/where/myfile.mcpl.gz"
  // Factor multiplied onto the ray weight p before it is stored
  // (1.0 in weight_mode 2, otherwise mcget_ncount()).
  double weight_scale;

  // OPENACC variables:
  //(Note that cogen does not respect ifdefs here, so keep them also in non-acc mode)
  // Per-ray buffers filled in TRACE and flushed to the MCPL file in SAVE.
  // Position:
  DArray1d X;
  DArray1d Y;
  DArray1d Z;
  // Wavevector:
  DArray1d KX;
  DArray1d KY;
  DArray1d KZ;
  // Polarisation (E) vector:
  DArray1d EX;
  DArray1d EY;
  DArray1d EZ;
  // Phase, time and weight:
  DArray1d PHI;
  DArray1d T;
  DArray1d P;
  // Userflag values; only allocated when userflagenabled is set.
  DArray1d U;
  // Number of rays that have requested a buffer slot (may exceed buffermax).
  int captured;
%}

INITIALIZE
%{
  /* Build the output file name from the filename parameter, falling back to
   * the component name when no filename was given. The result is owned by
   * this component and released in FINALLY. */
  {
    char* tmp = mcfull_file ((filename && filename[0]) ? filename : NAME_CURRENT_COMP, NULL);
    outfilename = mcpl_name_helper (tmp, 'G');
    free (tmp);
  }

  #if defined (USE_MPI)
  unsigned long iproc = mpi_node_rank;
  unsigned long nproc = mpi_node_count;
  #else
  unsigned long iproc = 0;
  unsigned long nproc = 1;
  #endif

  outputfile = mcpl_create_outfile_mpi (outfilename, iproc, nproc);

  mcpl_enable_universal_pdgcode (outputfile, 22); /*all particles are photons*/

  char line[256];
  snprintf (line, sizeof (line), "%s %s %s", MCCODE_NAME, MCCODE_VERSION, instrument_name);
  mcpl_hdr_set_srcname (outputfile, line);
  snprintf (line, sizeof (line), "Output by COMPONENT: %s", NAME_CURRENT_COMP);
  mcpl_hdr_add_comment (outputfile, line);

  /*also add the instrument file and the command line as blobs*/
  if (mcpl_file_exist (instrument_source)) {
    uint64_t instrsrc_buflen;
    char* instrsrc_buf = NULL;
    mcpl_read_file_to_buffer (instrument_source,
                              100 * 1024 * 1024, // 100mb, must be less than uint32 max
                              1,                 // text,
                              &instrsrc_buflen, &instrsrc_buf);
    mcpl_hdr_add_data (outputfile, "mccode_instr_file", (uint32_t)instrsrc_buflen, instrsrc_buf);
    free (instrsrc_buf);
  } else {
    fprintf (stderr,
             "\nWarning (%s): Source instrument file (%s)"
             " not found, hence not embedded.\n",
             NAME_CURRENT_COMP, instrument_source);
  }

  /* Reconstruct the command line ("exe par1=val1 par2=val2 ...") and embed
   * it in the header. snprintf returns the length it WOULD have written, so
   * the running offset is clamped to the buffer: a long parameter list is
   * truncated instead of overrunning clr. */
  {
    char clr[2048];
    char Parameters[CHAR_BUF_LENGTH];
    const size_t clr_max = sizeof (clr) - 1; /* last usable index (terminator) */
    size_t used = 0;
    int written;

    clr[0] = '\0';
    written = snprintf (clr, sizeof (clr), "%s", instrument_exe);
    if (written > 0)
      used = ((size_t)written < clr_max) ? (size_t)written : clr_max;

    for (int ii = 0; ii < numipar && used < clr_max; ii++) {
      (*mcinputtypes[mcinputtable[ii].type].printer) (Parameters, mcinputtable[ii].par);
      written = snprintf (clr + used, sizeof (clr) - used, " %s=%s", mcinputtable[ii].name, Parameters);
      if (written < 0)
        break;
      if ((size_t)written < sizeof (clr) - used)
        used += (size_t)written;
      else
        used = clr_max; /* output was truncated; buffer is full */
    }
    clr[used] = '\0';
    mcpl_hdr_add_data (outputfile, "mccode_cmd_line", (uint32_t)used, clr);
  }

  /* Resolve the weight mode: -1 means "unspecified" and falls back to the
   * classic mode 2 with a deprecation warning. */
  if (weight_mode == -1) {
    MPI_MASTER (fprintf (stderr,
                         "\nWarning (%s): weight_mode unspecified so defaulting to 2"
                         " (classic mode) which will soon become unavailable and tooling"
                         " might need updating. Please consult"
                         " https://github.com/mccode-dev/McCode/discussions/2076 for more"
                         " information.\n",
                         NAME_CURRENT_COMP););
    weight_mode = 2;
  } else if (weight_mode == 2) {
    MPI_MASTER (fprintf (stderr,
                         "\nWarning (%s): weight_mode=2 selected. This mode will"
                         " soon become unavailable. Please consult"
                         " https://github.com/mccode-dev/McCode/discussions/2076 for more"
                         " information.\n",
                         NAME_CURRENT_COMP););
  }
  if (weight_mode != 1 && weight_mode != 2) {
    fprintf (stderr, "\nERROR (%s): Unvalid value of weight_mode provided.\n", NAME_CURRENT_COMP);
    exit (1);
  }

  weight_scale = 1.0;
  if (weight_mode != 2) {
    mcpl_hdr_add_stat_sum (outputfile, "initial_ray_count", -1.0); // reserve
    // Multiply with mcget_ncount() to undo division supposedly performed in
    // source:
    weight_scale = (double)mcget_ncount ();
  }

  if (polarisationuse)
    mcpl_enable_polarisation (outputfile);

  if (doubleprec)
    mcpl_enable_doubleprec (outputfile);

  /* Userflags are only stored when a describing comment was supplied. */
  userflagenabled = 0;
  if (userflagcomment && strlen (userflagcomment) != 0) {
    userflagenabled = 1;
    mcpl_enable_userflags (outputfile);
    snprintf (line, sizeof (line), "userflags: %s", userflagcomment);
    mcpl_hdr_add_comment (outputfile, line);
  }

  #ifndef OPENACC
  /*pointer to the single particle storage area*/
  particle = mcpl_get_empty_particle (outputfile);
  #else
  /* In OPENACC mode rays are buffered in arrays during TRACE and written to
   * the MCPL file in SAVE. Each array is allocated exactly once. */
  particle = NULL;
  if (!buffermax) {
    buffermax = mcget_ncount ();
  }
  X = create_darr1d (buffermax);
  Y = create_darr1d (buffermax);
  Z = create_darr1d (buffermax);
  KX = create_darr1d (buffermax);
  KY = create_darr1d (buffermax);
  KZ = create_darr1d (buffermax);
  EX = create_darr1d (buffermax);
  EY = create_darr1d (buffermax);
  EZ = create_darr1d (buffermax);
  PHI = create_darr1d (buffermax);
  T = create_darr1d (buffermax);
  P = create_darr1d (buffermax);
  if (userflagenabled) {
    U = create_darr1d (buffermax);
  }
  captured = 0;
  #endif
%}

TRACE
%{
  #ifdef OPENACC
  /* OPENACC mode: reserve a slot index atomically and copy the ray state
   * into the buffers; the MCPL records are produced later in SAVE. */
  int cap;
  #pragma acc atomic capture
  {
    cap = captured++;
  }

  //  unsigned long long i=_particle->_uid;// % GPU_INNERLOOP;
  if (cap < ceil (buffermax)) {
    X[cap] = x;
    Y[cap] = y;
    Z[cap] = z;
    KX[cap] = kx;
    KY[cap] = ky;
    KZ[cap] = kz;
    EX[cap] = Ex;
    EY[cap] = Ey;
    EZ[cap] = Ez;
    PHI[cap] = phi;
    T[cap] = t;
    P[cap] = p;
    if (userflagenabled) {
      int fail;
      double uvar = particle_getvar (_particle, userflag, &fail);
      if (fail)
        uvar = 0;
      U[cap] = uvar;
    }
    SCATTER;
  }

  #else
  /* CPU mode: fill the single scratch particle and append it to the file. */
  double nrm;
  /*positions are in cm*/
  particle->position[0] = x * 100;
  particle->position[1] = y * 100;
  particle->position[2] = z * 100;

  if (polarisationuse) {
    particle->polarisation[0] = Ex;
    particle->polarisation[1] = Ey;
    particle->polarisation[2] = Ez;
  }

  nrm = sqrt (kx * kx + ky * ky + kz * kz);
  /*ekin is in MeV, in McXtrace we use keV*/
  particle->ekin = K2E * nrm * 1e-3;
  particle->direction[0] = kx / nrm;
  particle->direction[1] = ky / nrm;
  particle->direction[2] = kz / nrm;
  /*time in ms:*/
  particle->time = t * 1e3;
  /*weight in unspecified units:*/
  particle->weight = p * weight_scale;
  if (userflagenabled) {
    // TODO: Reconsider this passing of uint32_t flags through a double
    int fail;
    double uvar = particle_getvar (_particle, userflag, &fail);
    if (fail)
      uvar = 0;
    particle->userflags = (uint32_t)uvar;
  }

  /* Per-ray trace of the first 10 rays for verbose > 2 (as documented for
   * the verbose parameter). The run number is cast explicitly so the
   * argument type matches the %lld conversion. */
  MPI_MASTER (if (verbose >= 3 && mcrun_num < 10) {
    printf ("id=%lld\tpdg=22\tekin=%g MeV\tx=%g cm\ty=%g cm\tz=%g cm\tux=%g\tuy=%g\tuz=%g\tt=%g ms\tweight=%g\tpolx=%g\tpoly=%g\tpolz=%g\n", (long long)mcrun_num,
            particle->ekin, particle->position[0], particle->position[1], particle->position[2], particle->direction[0], particle->direction[1],
            particle->direction[2], particle->time, particle->weight, particle->polarisation[0], particle->polarisation[1], particle->polarisation[2]);
  });

  mcpl_add_particle (outputfile, particle);

  SCATTER;
  #endif
%}

SAVE
%{
  #ifdef OPENACC
  /* OPENACC mode: convert the rays buffered during TRACE into MCPL records.
   * Nothing is done here in CPU mode, where TRACE writes rays directly. */
  double nrm;
  unsigned long long i;
  /* Number of valid buffer entries. `captured` counts every ray that asked
   * for a slot, but TRACE only stores rays with index < ceil(buffermax), so
   * the loop must not run past that limit. */
  unsigned long long nstored = (captured > 0) ? (unsigned long long)captured : 0;
  if (captured > ceil (buffermax)) {
    double limit = ceil (buffermax);
    fprintf (stderr, "MCPL_output captured %g particles which is more than the buffersize (%g)!\n", (double)captured, buffermax);
    nstored = (limit > 0) ? (unsigned long long)limit : 0;
  }
  /* Zero-initialised so polarisation/userflags are well defined even when
   * the corresponding options are disabled. */
  mcpl_particle_t Particle = { 0 };
  for (i = 0; i < nstored; i++) {
    if (P[i] > 0) {
      /*positions are in cm*/
      Particle.position[0] = X[i] * 100;
      Particle.position[1] = Y[i] * 100;
      Particle.position[2] = Z[i] * 100;

      if (polarisationuse) {
        Particle.polarisation[0] = EX[i];
        Particle.polarisation[1] = EY[i];
        Particle.polarisation[2] = EZ[i];
      }

      nrm = sqrt (KX[i] * KX[i] + KY[i] * KY[i] + KZ[i] * KZ[i]);
      /*ekin is in MeV, in McXtrace keV*/
      Particle.ekin = K2E * nrm * 1e-3;
      Particle.direction[0] = KX[i] / nrm;
      Particle.direction[1] = KY[i] / nrm;
      Particle.direction[2] = KZ[i] / nrm;
      /*time in ms:*/
      Particle.time = T[i] * 1e3;
      /*weight in unspecified units:*/
      Particle.weight = P[i] * weight_scale;
      if (userflagenabled)
        Particle.userflags = (uint32_t)U[i];

      /* Trace the first 10 buffered rays for verbose > 2. The buffer index
       * is used as id; the file's universal pdg code is 22 (photon). */
      if (verbose >= 3 && i < 10) {
        printf ("id=%llu\tpdg=22\tekin=%g MeV\tx=%g cm\ty=%g cm\tz=%g cm\tux=%g\tuy=%g\tuz=%g\tt=%g ms\tweight=%g\tpolx=%g\tpoly=%g\tpolz=%g\n", i,
                Particle.ekin, Particle.position[0], Particle.position[1], Particle.position[2], Particle.direction[0], Particle.direction[1],
                Particle.direction[2], Particle.time, Particle.weight, Particle.polarisation[0], Particle.polarisation[1], Particle.polarisation[2]);
      }

      mcpl_add_particle (outputfile, &Particle);
    }
  }
  #endif
%}

FINALLY
%{
  /* In rescaling mode, write the real ray count into the stat:sum entry
   * that INITIALIZE reserved with a placeholder value of -1. */
  if (weight_mode != 2)
    mcpl_hdr_add_stat_sum (outputfile, "initial_ray_count", (double)mcget_ncount ());
  mcpl_closeandgzip_outfile (outputfile);

  #if defined (USE_MPI)
  /* Wait until every node has closed its file before merging. */
  MPI_Barrier (MPI_COMM_WORLD);
  unsigned long iproc = mpi_node_rank;
  unsigned long nproc = mpi_node_count;
  #else
  unsigned long iproc = 0;
  unsigned long nproc = 1;
  #endif

  /* Only the first node merges the per-node files. */
  if (iproc == 0)
    mcpl_merge_outfiles_mpi (outfilename, nproc);
  if (verbose) {
    MPI_MASTER (printf ("\n\nMCPL output summary from %s\n", outfilename); mcpl_dump (outfilename, 0, 0, 10););
  }
  #ifdef OPENACC
  /* Release the ray buffers allocated in INITIALIZE. */
  destroy_darr1d (X);
  destroy_darr1d (Y);
  destroy_darr1d (Z);
  destroy_darr1d (KX);
  destroy_darr1d (KY);
  destroy_darr1d (KZ);
  destroy_darr1d (EX);
  destroy_darr1d (EY);
  destroy_darr1d (EZ);
  destroy_darr1d (PHI);
  destroy_darr1d (T);
  destroy_darr1d (P);
  /* U is only allocated when userflags are enabled. */
  if (userflagenabled) {
    destroy_darr1d (U);
  }
  #endif
  free (outfilename);
%}

MCDISPLAY
%{
  /* Square outline with corners at (+-0.2, +-0.2) in the z=0 plane. */
  multiline (5, 0.2, 0.2, 0.0, -0.2, 0.2, 0.0, -0.2, -0.2, 0.0, 0.2, -0.2, 0.0, 0.2, 0.2, 0.0);

  /* Letter "M": a five-point polyline. */
  multiline (5, -0.085, -0.085, 0.0, -0.085, 0.085, 0.0, -0.045, -0.085, 0.0, -0.005, 0.085, 0.0, -0.005, -0.085, 0.0);

  /* Letter "O": an ellipse centred at (0.045, 0) with semi-axes 0.04 and
   * 0.08, approximated by 32 straight segments. */
  double step = 2 * M_PI / 32;
  double angle = 0;
  int seg;
  for (seg = 0; seg < 32; seg++) {
    double next = angle + step;
    line (0.04 * cos (angle) + 0.045, 0.08 * sin (angle), 0, 0.04 * cos (next) + 0.045, 0.08 * sin (next), 0);
    angle = next;
  }
%}

END
