/* mincPLS.c                                                                  */
/*                                                                            */
/* Re-implementation of Randy McIntosh's PLS stuff for MINC files             */
/*                                                                            */
/* Andrew Janke - a.janke@gmail.com                                           */
/*                                                                            */
/* Copyright Andrew Janke, McConnell Brain Imaging Centre                     */
/* Permission to use, copy, modify, and distribute this software and its      */
/* documentation for any purpose and without fee is hereby granted,           */
/* provided that the above copyright notice appear in all copies.  The        */
/* author and McGill University make no representations about the             */
/* suitability of this software for any purpose.  It is provided "as is"      */
/* without express or implied warranty.                                       */

#include <config.h>
#include <stdlib.h>
#include <stdio.h>
#include <sys/time.h>
#include <unistd.h>
#include <math.h>
#include <string.h>
#include <ParseArgv.h>
#include <time_stamp.h>
#include <voxel_loop.h>
#include <gsl/gsl_vector.h>
#include <gsl/gsl_matrix.h>
#include <gsl/gsl_linalg.h>
#include <minc_vector_io.h>
#include <volume_io.h>

#ifndef FALSE
#  define FALSE 0
#endif
#ifndef TRUE
#  define TRUE 1
#endif

#define DEFAULT_BOOL -1
#define MAX_NUM_ROWS 1000
#define MAX_NUM_COLUMNS 1000

#define SQR2(x) ((x) * (x))


/* typedefs */
/* State shared between main() and the voxel_loop callbacks, passed to them  */
/* as caller_data. Field order matters: the file-scope instance `md` is      */
/* initialised positionally.                                                 */
typedef struct {
   /* input parameters */
   int      masking;      /* non-zero: callbacks only process voxels whose   */
                          /* mask buffer value is within 0.5 of mask_val     */
   double   mask_val;     /* mask value to match (set by -mask_val)          */
   int      mask_idx;     /* index into input_data[] of the mask buffer      */

   int      n_contrasts;  /* number of columns of the design matrix          */
   int      n_infiles;    /* number of rows of the design matrix (one per    */
                          /* input file)                                     */

   int      pls_nvox;     /* running voxel counter, used by the callbacks as */
                          /* the row index into pls_S / pls_B; main() resets */
                          /* it to 0 before each voxel_loop                  */

   gsl_matrix *pls_X;     /* design matrix (n_infiles x n_contrasts)         */
   gsl_matrix *pls_S;     /* cross-correlation of X and the image data       */
                          /* (n_voxels x n_contrasts), filled by             */
                          /* correlation_loop                                */
   gsl_matrix *pls_B;     /* matrix written out by output_B_loop; main()     */
                          /* points it at the SVD-overwritten copy of S      */

   } Loop_Data;

/* function prototypes */
/* voxel_loop callback: fills md->pls_S with the correlation of each design  */
/* matrix column against the per-file voxel values                           */
void     correlation_loop(void *caller_data, long num_voxels, int input_num_buffers,
                 int input_vector_length, double *input_data[],
                 int output_num_buffers, int output_vector_length,
                 double *output_data[], Loop_Info * loop_info);
/* voxel_loop callback: copies rows of md->pls_B into the output buffers     */
void     output_B_loop(void *caller_data, long num_voxels, int input_num_buffers,
                 int input_vector_length, double *input_data[],
                 int output_num_buffers, int output_vector_length,
                 double *output_data[], Loop_Info * loop_info);
/* reads the design file: one input filename plus a row of values per line   */
int get_design_matrix(char *design_fname, char ***infiles, gsl_matrix **design_matrix);
/* -version handler: prints package information and exits                    */
void     print_version_info(void);

/* Argument variables and table */
/* set by -verbose / -clobber */
static int verbose = FALSE;
static int clobber = FALSE;
/* output type, signedness and valid range; handed to set_loop_datatype() */
nc_type datatype = MI_ORIGINAL_TYPE;
int is_signed = FALSE;
double valid_range[2] = {0.0, 0.0};
/* stays DEFAULT_BOOL unless -copy_header / -nocopy_header is given */
int copy_all_header = DEFAULT_BOOL;
/* voxel_loop buffer size in kbytes (multiplied by 1024 before use) */
static int max_buffer = 4 * 1024;
int check_dim_info = TRUE;
/* set by -mask. NOTE(review): nothing else in the visible source reads    */
/* this, and md.masking is never switched on - confirm whether masking is  */
/* meant to be wired up                                                    */
static char *mask_fname = NULL;
/* shared callback state; positional initialiser follows the field order   */
/* of Loop_Data: masking, mask_val, mask_idx / n_contrasts, n_infiles,     */
/* pls_nvox / pls_X, pls_S, pls_B                                          */
static Loop_Data md = {
   FALSE, 1.0, 0, 
   0, 0, 0,
   NULL, NULL, NULL,
   };

/* ParseArgv option table. Each entry is {key, type, source, destination,   */
/* help text}; ARGV_CONSTANT entries store the (cast) source value into the */
/* destination variable, ARGV_INT/ARGV_FLOAT/ARGV_STRING entries use the    */
/* source slot as the number of values to read.                             */
static ArgvInfo argTable[] = {
   {NULL, ARGV_HELP, (char *)NULL, (char *)NULL,
    "General options:"},
   {"-version", ARGV_FUNC, (char *)print_version_info, (char *)NULL,
    "print version info and exit."},
   {"-verbose", ARGV_CONSTANT, (char *)TRUE, (char *)&verbose,
    "print out extra information."},
   {"-clobber", ARGV_CONSTANT, (char *)TRUE, (char *)&clobber,
    "clobber existing files."},
   
   /* output data type, signedness and range */
   {"-filetype", ARGV_CONSTANT, (char *) MI_ORIGINAL_TYPE, (char *) &datatype,
    "use the data type of first file (default)."},
   {"-byte", ARGV_CONSTANT, (char *) NC_BYTE, (char *) &datatype,
    "write out byte data."},
   {"-short", ARGV_CONSTANT, (char *) NC_SHORT, (char *) &datatype,
    "write out short integer data."},
   {"-int", ARGV_CONSTANT, (char *) NC_INT, (char *) &datatype,
    "write out 32-bit integer data."},
   {"-float", ARGV_CONSTANT, (char *) NC_FLOAT, (char *) &datatype,
    "write out single-precision floating-point data."},
   {"-double", ARGV_CONSTANT, (char *) NC_DOUBLE, (char *) &datatype,
    "write out double-precision floating-point data."},
   {"-signed", ARGV_CONSTANT, (char *) TRUE, (char *) &is_signed,
    "write signed integer data."},
   {"-unsigned", ARGV_CONSTANT, (char *) FALSE, (char *) &is_signed,
    "write unsigned integer data (default if type specified)."},
   {"-range", ARGV_FLOAT, (char *) 2, (char *) valid_range,
    "valid range for output data."},

   /* header copying and dimension checking */
   {"-copy_header", ARGV_CONSTANT, (char *) TRUE, (char *) &copy_all_header,
    "copy all of the header from the first file."},
   {"-nocopy_header", ARGV_CONSTANT, (char *) FALSE, (char *) &copy_all_header,
    "do not copy all of the header from the first file."},    
   {"-check_dimensions", ARGV_CONSTANT, (char *) TRUE, (char *) &check_dim_info,
    "Check that files have matching dimensions (default)."},
   {"-nocheck_dimensions", ARGV_CONSTANT, (char *) FALSE, (char *) &check_dim_info,
    "Do not check that files have matching dimensions."},

   /* buffering and masking */
   {"-max_buffer", ARGV_INT, (char *)1, (char *)&max_buffer,
    "maximum size of buffers (in kbytes)"},
   {"-mask", ARGV_STRING, (char *)1, (char *)&mask_fname,
    "only perform calculations on voxels within the specified mask"},
   {"-mask_val", ARGV_FLOAT, (char *)1, (char *)&(md.mask_val),
    "mask value to use"},

   /* section headings only: no options are listed under these yet */
   {NULL, ARGV_HELP, (char *)NULL, (char *)NULL,
    "\nContrasts:"},

   {NULL, ARGV_HELP, (char *)NULL, (char *)NULL,
    "\nOutput Options:"},

   {NULL, ARGV_HELP, NULL, NULL, ""},
   {NULL, ARGV_END, NULL, NULL, NULL}
   };

int main(int argc, char *argv[])
{
   char   **infiles;
   char    *design_fname;
   char    *out_base;
   char   **outfiles;
   int      n_infiles, n_outfiles;
   char    *arg_string;
   int      i, j, n_voxels, n_contrasts;
   char    *buf;
   FILE    *fh;
   
   int      n_dims;
   int      sizes[MAX_DIMENSIONS];
   volume_input_struct input_info;
   Volume tmp_vol;
   
   /* for the SVD */
   gsl_matrix *svd_A;
   gsl_matrix *svd_B;
   gsl_vector *svd_D;
   gsl_vector *svd_work;
   
   
   Loop_Options *loop_opts;

   gsl_matrix *design_matrix;

 
   /* Save time stamp and args */
   arg_string = time_stamp(argc, argv);

   /* get arguments */
   if(ParseArgv(&argc, argv, argTable, 0) || (argc != 3)){
      fprintf(stderr, "\nUsage: %s [options] <design-matrix.dm> <out_base>\n", argv[0]);
      fprintf(stderr, "       %s -help\n\n", argv[0]);
      exit(EXIT_FAILURE);
      }

   /* get and check for the input design matrix and output base filenames */
   design_fname = argv[1];
   out_base = argv[2];

   if(access(design_fname, F_OK) != 0){
      fprintf(stderr, "%s: Couldn't find input design-matrix: %s\n\n", argv[0],
              design_fname);
      exit(EXIT_FAILURE);
      }

   /* read in the design matrix (X) */
   if(get_design_matrix(design_fname, &infiles, &design_matrix) != TRUE){
      fprintf(stderr, "Error: error in get_design_matrix_file. \n");
      exit(EXIT_FAILURE);
      }
   n_infiles = design_matrix->size1;
   n_contrasts = design_matrix->size2;

   /* a bit of output */
   if(verbose){
      fprintf(stdout, "+++ design matrix from %s +++\n", design_fname);
      for(i = 0; i < design_matrix->size1; i++){
         fprintf(stdout, " | [%s] - ", infiles[i]);
         for(j = 0; j < design_matrix->size2; j++){
            fprintf(stdout, "%5g ", gsl_matrix_get(design_matrix, i, j));
            }
         fprintf(stdout, "\n");
         }
      }

   /* set up and check for output files */
   n_outfiles = n_contrasts;
   outfiles = (char **)malloc(sizeof(char *) * n_outfiles);
   for(i = 0; i < n_outfiles; i++){
      outfiles[i] = (char *)malloc((strlen(out_base) + 10) * sizeof(char));
      sprintf(outfiles[i], "%s.SI%02d.mnc", out_base, i);

      if(access(outfiles[i], F_OK) == 0 && !clobber){
         fprintf(stderr, "%s: %s exists, use -clobber to overwrite\n\n",
                 argv[0], outfiles[i]);
         exit(EXIT_FAILURE);
         }
      }

   /* great big long-winded way to get the number of voxels in a minc file */
   n_dims = get_minc_file_n_dimensions(infiles[0]);
   if(verbose){ 
      fprintf(stdout, "Got %d dimensions from %s\n", n_dims, infiles[0]);
      }
   if(start_volume_input(infiles[0], n_dims, NULL,
                         NC_UNSPECIFIED, FALSE, 0.0, 0.0,
                         TRUE, &tmp_vol, (minc_input_options*)NULL, 
                         &input_info) != OK){
      fprintf(stderr, "[%s]: error reading in minc file %s. \n", argv[0], infiles[0]);
      exit(EXIT_FAILURE);
      }
   get_volume_sizes(tmp_vol, sizes);
   cancel_volume_input(tmp_vol, &input_info);
   n_voxels = 1;
   for(i=0; i<n_dims; i++){
      n_voxels *= sizes[i];
      }
   if(verbose){
      fprintf(stdout, "   # voxels: %d\n", n_voxels);
      }

   /* alloc space for the S matrix */
   md.pls_S = gsl_matrix_alloc(n_voxels, n_contrasts);

   /* set up for our first voxel_loop (setting up the S matrix) */
   md.pls_nvox = 0;
   md.n_contrasts = n_contrasts;
   md.n_infiles = n_infiles;
   md.pls_X = design_matrix;

   /* set up voxel loop options */
   loop_opts = create_loop_options();
   set_loop_verbose(loop_opts, TRUE);
   set_loop_clobber(loop_opts, clobber);
   
   set_loop_datatype(loop_opts, datatype, is_signed,
                     valid_range[0], valid_range[1]);

   set_loop_copy_all_header(loop_opts, copy_all_header);
   set_loop_buffer_size(loop_opts, (long)1024 * max_buffer);
   set_loop_check_dim_info(loop_opts, check_dim_info);


   /* do the loop to build the S matrix from infiles (Y) and design matrix (X)*/
   if(verbose){
      fprintf(stdout, "Creating cross correlation matrix from X and Y matrix (S)\n");
      }
   voxel_loop(n_infiles, infiles, 0, NULL, arg_string,
              loop_opts, correlation_loop, (void *)&md);

   /* alloc space for SVD matrix doo-hickeys */
   svd_B = gsl_matrix_alloc(n_voxels, n_contrasts);
   svd_A = gsl_matrix_alloc(n_contrasts, n_contrasts);
   svd_D = gsl_vector_alloc(n_contrasts);
   svd_work = gsl_vector_alloc(n_contrasts);

   /* make a working copy of the S matrix for the SVD as it is overwritten */
   gsl_matrix_memcpy(svd_B, md.pls_S);

   /* decompose the S matrix using SVD (S = ADB^t) */
   gsl_linalg_SV_decomp(svd_B, svd_A, svd_D, svd_work);
 
 
   /* allocate space for filename buffer */
   buf = (char*)malloc((strlen(out_base) + 20) * sizeof(char));
   
   
   /* write out the singular values (D) */
   sprintf(buf, "%s.sing-values.txt", out_base);
   fh = fopen(buf, "w");

   fprintf(fh, "# Singular Values (D)\n");
   fprintf(fh, "# -------------------\n");
   for(i=0; i<md.n_contrasts; i++){
      fprintf(fh, "SI%02d  %g\n", i, gsl_vector_get(svd_D, i));
      } 
   fprintf(fh, "\n"); 
   fclose(fh);
   
  
   /* write out the singular vectors (A) */
   sprintf(buf, "%s.sing-vectors.txt", out_base);
   fh = fopen(buf, "w");
   fprintf(fh, "# Singular Vectors for the Design Matrix (A)\n");
   fprintf(fh, "# ------------------------------------------\n");
  
   fprintf(fh, "# Contrast"); 
   for(i=0; i<md.n_contrasts; i++){
      fprintf(fh, "     SI%02d        ", i);
      }
   fprintf(fh, "\n");
   
   for(i=0; i<md.n_contrasts; i++){
      fprintf(fh, "%02d", i);
      for(j=0; j<md.n_contrasts; j++){
         fprintf(fh, "   %14g", gsl_matrix_get(svd_A, i, j));
         }
      fprintf(fh, "\n"); 
      } 
   fprintf(fh, "\n"); 
   fclose(fh);


   /* output the Salience Images as MINC files (B) */
   if(verbose){
      fprintf(stdout, "Outputting Salience images (B)\n");
      }
   md.pls_nvox = 0;
   md.pls_B = svd_B;
   voxel_loop(1, infiles, n_outfiles, outfiles, arg_string,
              loop_opts, output_B_loop, (void *)&md);

   /* be tidy */   
   free_loop_options(loop_opts);

   return (EXIT_SUCCESS);
   }

/* compute the normalised inner product of X (design matrix) and Y (data)    */
/*                                                                           */
/* voxel_loop callback. For each voxel in the current buffer and for each    */
/* contrast (column of X), the X column and the per-file voxel values (Y)    */
/* are z-scored and their correlation is stored in pls_S at                  */
/* (row md->pls_nvox, column contrast). md->pls_nvox is advanced once per    */
/* voxel, so it must be reset to 0 before the loop is started.               */
/*                                                                           */
/* caller_data          - pointer to the shared Loop_Data                    */
/* num_voxels,                                                               */
/* input_vector_length  - their product is the number of values per buffer   */
/* input_data           - one buffer per input file (indexed 0..n_infiles-1, */
/*                        plus md->mask_idx when masking is on)              */
/* the output arguments and loop_info are not used.                          */
/*                                                                           */
/* A 0 is stored (instead of NaN or an unwritten entry) when the voxel is    */
/* outside the mask, when there are fewer than two input files, or when      */
/* either X or Y has no variance.                                            */
void correlation_loop(void *caller_data, long num_voxels, int input_num_buffers,
             int input_vector_length, double *input_data[], int output_num_buffers,
             int output_vector_length, double *output_data[], Loop_Info * loop_info)
{
   Loop_Data *md = (Loop_Data *) caller_data;
   const int  n_files = md->n_infiles;
   const long n_values = num_voxels * input_vector_length;
   long     ivox;
   int      i, cont;
   int      in_mask, have_Y;
   
   double value_X, value_Y;
   double sum_X, sum_Y;
   double ssum_X, ssum_Y;
   double mean_X, mean_Y;
   double sd_X, sd_Y;

   double result;

   /* shut the compiler up */
   (void)input_num_buffers;
   (void)output_num_buffers;
   (void)output_vector_length;
   (void)output_data;
   (void)loop_info;
 
   /* for each voxel */
   for(ivox = 0; ivox < n_values; ivox++){

      /* nasty way that works for masking or not */
      in_mask = !md->masking ||
                (fabs(input_data[md->mask_idx][ivox] - md->mask_val) < 0.5);

      /* mean and stddev of Y do not depend on the contrast: do them once */
      mean_Y = 0;
      sd_Y = 0;
      have_Y = FALSE;
      if(in_mask && n_files > 1){
         sum_Y = ssum_Y = 0;
         for(i = 0; i < n_files; i++){
            sum_Y += input_data[i][ivox];
            ssum_Y += SQR2(input_data[i][ivox]);
            }
         mean_Y = sum_Y / n_files;
         sd_Y = sqrt((ssum_Y - (n_files * SQR2(mean_Y))) / (n_files - 1));

         /* (sd > 0) is also false for NaN, i.e. a slightly negative variance */
         have_Y = (sd_Y > 0);
         }

      /* for each contrast */
      for(cont=0; cont < md->n_contrasts; cont++){
         result = 0;

         if(have_Y){
            /* mean and stddev of this column of X */
            sum_X = ssum_X = 0;
            for(i = 0; i < n_files; i++){
               sum_X += gsl_matrix_get(md->pls_X, i, cont);
               ssum_X += SQR2(gsl_matrix_get(md->pls_X, i, cont));
               }
            mean_X = sum_X / n_files;
            sd_X = sqrt((ssum_X - (n_files * SQR2(mean_X))) / (n_files - 1));

            /* calculate the inner product after normalising the columns (zscore) */
            if(sd_X > 0){
               for(i = 0; i < n_files; i++){
                  value_X = (gsl_matrix_get(md->pls_X, i, cont) - mean_X) / sd_X;
                  value_Y = (input_data[i][ivox] - mean_Y) / sd_Y;
                  result += value_X * value_Y;
                  }
               result /= (n_files - 1);
               }
            }

         /* store result in S (always, so that no row is left unwritten) */
         gsl_matrix_set(md->pls_S, md->pls_nvox, cont, result);
         }
      
      /* increment our counter */
      md->pls_nvox++;
      }
   }

/* loop function to output the B matrix                                      */
/*                                                                           */
/* voxel_loop callback. For each voxel in the current buffer, element        */
/* (md->pls_nvox, i) of pls_B is written to output buffer i. md->pls_nvox is */
/* advanced once per voxel, so it must be reset to 0 before the loop is      */
/* started.                                                                  */
/*                                                                           */
/* caller_data          - pointer to the shared Loop_Data                    */
/* num_voxels,                                                               */
/* input_vector_length  - their product is the number of values per buffer   */
/* input_data           - only read (at md->mask_idx) when masking is on     */
/* output_num_buffers   - number of output buffers / columns of B written    */
/* output_data          - one buffer per output file                         */
/*                                                                           */
/* Voxels outside the mask are written as 0 so that the output buffers are   */
/* never left unset.                                                         */
void output_B_loop(void *caller_data, long num_voxels, int input_num_buffers,
             int input_vector_length, double *input_data[], int output_num_buffers,
             int output_vector_length, double *output_data[], Loop_Info * loop_info)
{
   Loop_Data *md = (Loop_Data *) caller_data;
   const long n_values = num_voxels * input_vector_length;
   long     ivox;
   int      i;
   int      in_mask;
   
   /* shut the compiler up */
   (void)input_num_buffers;
   (void)output_vector_length;
   (void)loop_info;
 
   /* for each voxel */
   for(ivox = 0; ivox < n_values; ivox++){

      /* nasty way that works for masking or not */
      in_mask = !md->masking ||
                (fabs(input_data[md->mask_idx][ivox] - md->mask_val) < 0.5);

      /* write results out */
      for(i = 0; i < output_num_buffers; i++){
         output_data[i][ivox] = in_mask ? gsl_matrix_get(md->pls_B, md->pls_nvox, i)
                                        : 0.0;
         }
      
      /* increment our counter */
      md->pls_nvox++;
      }
   }

/* this is largely a mash from glim_image except we use GSL here             */
/*                                                                           */
/* Reads a design file in which every line holds an input filename           */
/* (terminated by a space) followed by one numeric value per contrast.       */
/*                                                                           */
/* design_fname  - file to read                                              */
/* infiles       - receives a malloc()ed array of malloc()ed filenames, one  */
/*                 per row; the caller owns it                               */
/* design_matrix - receives a newly allocated (rows x columns) gsl_matrix;   */
/*                 the caller owns it                                        */
/*                                                                           */
/* Returns TRUE on success and FALSE when the file cannot be opened or       */
/* contains no rows. (The old code returned EXIT_FAILURE here, which has the */
/* same value as TRUE on common platforms and so looked like success.)       */
/* Malformed content - too many rows/columns, an unreadable value or rows of */
/* differing length - prints a message and exits.                            */
/*                                                                           */
/* NOTE(review): mni_input_string() may allocate a fresh string on every     */
/* call; if so the alloc_string() buffer and all but the last filename leak. */
/* Confirm against volume_io before changing the string handling.            */
int get_design_matrix(char *design_fname, char ***infiles, gsl_matrix **design_matrix)
{
   FILE    *fp;
   int      num_columns, num_rows;
   int      i, j, prev_j;

   gsl_matrix *tmp_matrix;

   STRING   filename;
   char     tmpc;
   double   value;
   int      done;
   
   /* open the design file */
   fp = fopen(design_fname, "r");
   if(fp == NULL){
      fprintf(stderr, "\nError opening design-matrix %s.\n\n", design_fname);
      return FALSE;
      }

   /* allocate tmp space */
   tmp_matrix = gsl_matrix_alloc(MAX_NUM_ROWS, MAX_NUM_COLUMNS);

   *infiles = malloc(sizeof(char *) * (MAX_NUM_ROWS + 2));
   if(*infiles == NULL){
      fprintf(stderr, "\nError: out of memory reading %s.\n", design_fname);
      exit(EXIT_FAILURE);
      }
   filename = alloc_string(4096);

   if(verbose){
      fprintf(stdout, "Reading Design file [%s]\n", design_fname);
      }
   
   /* for each line of the input file... */
   i = 0;
   j = 0;
   prev_j = -1;
   while(mni_input_string(fp, &filename, (char)' ', (char)0) == OK){

      /* make sure this row still fits before anything is stored */
      if(i >= MAX_NUM_ROWS){
         fprintf(stderr,
                 "\nError: more rows in design_matrix [%d] than MAX_NUM_ROWS [%d]\n",
                 i + 1, MAX_NUM_ROWS);
         exit(EXIT_FAILURE);
         }

      /* allocate some space for the filename */
      (*infiles)[i] = malloc(sizeof(char) * (strlen(filename) + 1));
      if((*infiles)[i] == NULL){
         fprintf(stderr, "\nError getting name of infile %d in %s.",
                 i + 1, design_fname);
         exit(EXIT_FAILURE);
         }
      strcpy((*infiles)[i], filename);

      /* now get the values */
      j = 0;
      done = FALSE;
      while(!done){

         /* make sure this column still fits before it is stored */
         if(j >= MAX_NUM_COLUMNS){
            fprintf(stderr,
                    "\nError: more columns in design_matrix [%d] than MAX_NUM_COLUMNS [%d]\n",
                    j + 1, MAX_NUM_COLUMNS);
            exit(EXIT_FAILURE);
            }

         /* get the value */
         if(input_double(fp, &value) != OK){
            fprintf(stderr, "\nError reading value %d on line [%d] of %s\n",
                    j + 1, i, design_fname);
            exit(EXIT_FAILURE);
            }
         gsl_matrix_set(tmp_matrix, i, j, value);
         j++;

         /* check if we are done: a newline, or the end of the file when the */
         /* last line has no trailing newline                                */
         if(input_character(fp, &tmpc) != OK){
            done = TRUE;
            }
         else{
            done = (tmpc == '\n');
            unget_character(fp, tmpc);
            }
         }

      /* check number of values */
      if((prev_j != -1) && (j != prev_j)){
         fprintf(stderr, "\nError: mismatched number of entries on line [%d]\n", i);
         exit(EXIT_FAILURE);
         }
      prev_j = j;

      i++;
      }
   num_rows = i;
   num_columns = j;

   /* an empty design file has nothing to allocate */
   if(num_rows == 0){
      fprintf(stderr, "\nError: no rows found in design-matrix %s.\n\n", design_fname);
      fclose(fp);
      delete_string(filename);
      gsl_matrix_free(tmp_matrix);
      free(*infiles);
      *infiles = NULL;
      return FALSE;
      }

   /* populate the design matrix */
   *design_matrix = gsl_matrix_alloc(num_rows, num_columns);
   for(i = 0; i < num_rows; i++){
      for(j = 0; j < num_columns; j++){
         gsl_matrix_set(*design_matrix, i, j, gsl_matrix_get(tmp_matrix, i, j));
         }
      }

   /* clean up */
   fclose(fp);
   delete_string(filename);
   gsl_matrix_free(tmp_matrix);

   return TRUE;
   }

/* -version handler: writes the package name, version and bug-report        */
/* address to stdout, framed by blank lines, then exits successfully        */
void print_version_info(void)
{
   fputs("\n", stdout);
   printf("%s version %s\n", PACKAGE, VERSION);
   printf("Comments to %s\n", PACKAGE_BUGREPORT);
   fputs("\n", stdout);

   exit(EXIT_SUCCESS);
   }
