#include <stdio.h>

#include <stdlib.h>

#include <math.h>

#include <assert.h>

#include "imatrix.h"

#include "ETF.h"


#include "stackmem.h"


#ifdef MEMDEBUG

#include "mnemosyne.h"

#endif


/* Diagnostic switch defined in another translation unit; when non-zero the
   routines in this file print checksums and progress to stderr. */
extern int debug;

/* Round to nearest int by adding 0.5 and truncating.  Because the cast
   truncates toward zero this is only a true nearest-integer rounding for
   non-negative x.  NOTE(review): this macro takes the name of the <math.h>
   round() function for the rest of the file. */
#define round(x) ((int) ((x) + 0.5))


ETF *New_ETF(int i, int j) {
  ETF *ETF = (struct ETF *)claim(sizeof(struct ETF)); assert(ETF);
  ETF->max_grad = 1.0; /* moving this to the start instead of the end of this code
                          fixed the bug where ETF->p[0] appeared to be overwritten */
  ETF->Nr = i; ETF->Nc = j;
  ETF->p = (Vect **)claim((sizeof(struct Vect *)) * ETF->Nr); assert(ETF->p);
  for (i = 0; i < ETF->Nr; i++) {
    ETF->p[i] = (Vect *)claim((sizeof(struct Vect)) * ETF->Nc); assert(ETF->p[i]);
    for (j = 0; j < ETF->Nc; j++) { /* unnecessary but may help find bug */
      ETF->p[i][j].tx=0.0; ETF->p[i][j].ty=0.0; ETF->p[i][j].mag=0.0;
    }
  }
  return ETF;
}

/*
 * Copy the contents of s into this.
 *
 * Both fields are expected to have identical dimensions (checked with
 * assert()).  The tx, ty and mag members of every cell are copied, followed
 * by the dimensions and max_grad.
 */
void ETF_copy(ETF *this, ETF *s) {
  int row, col;

  assert(this->Nr == s->Nr);
  assert(this->Nc == s->Nc);

  for (row = 0; row < s->Nr; row++) {
    Vect *dst = this->p[row];
    Vect *src = s->p[row];
    for (col = 0; col < s->Nc; col++) {
      dst[col].tx = src[col].tx;
      dst[col].ty = src[col].ty;
      dst[col].mag = src[col].mag;
    }
  }

  /* Redundant when the asserts are active, but still performed when they are
     compiled out. */
  this->Nr = s->Nr;
  this->Nc = s->Nc;
  this->max_grad = s->max_grad;
}

/* Number of rows (Nr) of the field pointed to by ETF. */
#define ETF_getRow(ETF) ((ETF)->Nr)

/* Number of columns (Nc) of the field pointed to by ETF. */
#define ETF_getCol(ETF) ((ETF)->Nc)

/* Current max_grad value of the field pointed to by ETF. */
#define GetMaxGrad(ETF) ((ETF)->max_grad)


/*
 * Debugging checksum of a field: the plain running sum of tx, ty and mag over
 * all cells, visited row by row.  Despite the name this is not a CRC.
 */
double ETF_crc(ETF *e) {
  double total = 0.0;
  int row, col;

  for (row = 0; row < e->Nr; row++) {
    const Vect *cells = e->p[row];
    for (col = 0; col < e->Nc; col++) {
      /* Keep the tx, ty, mag accumulation order: floating-point addition is
         not associative. */
      total += cells[col].tx;
      total += cells[col].ty;
      total += cells[col].mag;
    }
  }
  return total;
}

#include "myvec.h"


#ifdef NEVER

/*
 * Single-pass initialisation of the field from image (compiled only when
 * NEVER is defined).
 *
 * For every interior pixel a 3x3 weighted difference is taken in each
 * direction, the resulting vector is turned into its perpendicular
 * (tx, ty) = (-gy, gx), and its length is stored in mag.  max_grad tracks the
 * largest length seen.  Border cells copy their nearest interior neighbour,
 * corner cells average their two edge neighbours, and normalize() is applied
 * at the end.
 */
void set(ETF *this, imatrix *image) {
  int i, j;
  /* 1020 == 4 * 255; presumably the largest response of the 3x3 kernel for
     0..255 input -- TODO confirm the image value range. */
  double MAX_VAL = 1020.; 
  double v[2];
  
  this->max_grad = -1.;
  
  for (i = 1; i < image->Nr - 1; i++) { 
    for (j = 1; j < image->Nc - 1; j++) {
      /* ////////////////////////////////////////////////////////////// */
      /* tx: row i+1 minus row i-1; ty: column j+1 minus column j-1 */
      this->p[i][j].tx = (image->p[i+1][j-1] + 2*(double)image->p[i+1][j] + image->p[i+1][j+1] 
			  - image->p[i-1][j-1] - 2*(double)image->p[i-1][j] - image->p[i-1][j+1]) / MAX_VAL;
      this->p[i][j].ty = (image->p[i-1][j+1] + 2*(double)image->p[i][j+1] + image->p[i+1][j+1]
			  - image->p[i-1][j-1] - 2*(double)image->p[i][j-1] - image->p[i+1][j-1]) / MAX_VAL;
      /* /////////////////////////////////////////// */
      /* replace the difference vector by its perpendicular */
      v[0] = this->p[i][j].tx;
      v[1] = this->p[i][j].ty;
      this->p[i][j].tx = -v[1];
      this->p[i][j].ty = v[0];
      /* //////////////////////////////////////////// */
      this->p[i][j].mag = sqrt(this->p[i][j].tx * this->p[i][j].tx + this->p[i][j].ty * this->p[i][j].ty);
      
      if (this->p[i][j].mag > this->max_grad) {
	this->max_grad = this->p[i][j].mag;
      }
    }
  }
  
  /* first and last column: copy the adjacent interior column */
  for (i = 1; i <= this->Nr - 2; i++) {
    this->p[i][0].tx = this->p[i][1].tx;
    this->p[i][0].ty = this->p[i][1].ty;
    this->p[i][0].mag = this->p[i][1].mag;
    this->p[i][this->Nc - 1].tx = this->p[i][this->Nc - 2].tx;
    this->p[i][this->Nc - 1].ty = this->p[i][this->Nc - 2].ty;
    this->p[i][this->Nc - 1].mag = this->p[i][this->Nc - 2].mag;
  }
  
  /* first and last row: copy the adjacent interior row */
  for (j = 1; j <= this->Nc - 2; j++) {
    this->p[0][j].tx = this->p[1][j].tx;
    this->p[0][j].ty = this->p[1][j].ty;
    this->p[0][j].mag = this->p[1][j].mag;
    this->p[this->Nr - 1][j].tx = this->p[this->Nr - 2][j].tx;
    this->p[this->Nr - 1][j].ty = this->p[this->Nr - 2][j].ty;
    this->p[this->Nr - 1][j].mag = this->p[this->Nr - 2][j].mag;
  }
  
  /* four corners: mean of the two neighbouring edge cells filled in above */
  this->p[0][0].tx = ( this->p[0][1].tx + this->p[1][0].tx ) / 2;
  this->p[0][0].ty = ( this->p[0][1].ty + this->p[1][0].ty ) / 2;
  this->p[0][0].mag = ( this->p[0][1].mag + this->p[1][0].mag ) / 2;
  this->p[0][this->Nc-1].tx = ( this->p[0][this->Nc-2].tx + this->p[1][this->Nc-1].tx ) / 2;
  this->p[0][this->Nc-1].ty = ( this->p[0][this->Nc-2].ty + this->p[1][this->Nc-1].ty ) / 2;
  this->p[0][this->Nc-1].mag = ( this->p[0][this->Nc-2].mag + this->p[1][this->Nc-1].mag ) / 2;
  this->p[this->Nr-1][0].tx = ( this->p[this->Nr-1][1].tx + this->p[this->Nr-2][0].tx ) / 2;
  this->p[this->Nr-1][0].ty = ( this->p[this->Nr-1][1].ty + this->p[this->Nr-2][0].ty ) / 2;
  this->p[this->Nr-1][0].mag = ( this->p[this->Nr-1][1].mag + this->p[this->Nr-2][0].mag ) / 2;
  this->p[this->Nr-1][this->Nc - 1].tx = (this->p[this->Nr - 1][this->Nc - 2].tx + this->p[this->Nr - 2][this->Nc - 1].tx)/2;
  this->p[this->Nr-1][this->Nc - 1].ty = (this->p[this->Nr - 1][this->Nc - 2].ty + this->p[this->Nr - 2][this->Nc - 1].ty)/2;
  this->p[this->Nr-1][this->Nc - 1].mag= (this->p[this->Nr - 1][this->Nc - 2].mag+this->p[this->Nr - 2][this->Nc - 1].mag)/2;

  /* unit-length tangents, magnitudes divided by max_grad */
  normalize(this);
}
#endif


void set2(ETF *this, imatrix *image) {
  int i, j;
  double MAX_VAL = 1020.; 
  double v[2];
  mymatrix *tmp = mymatrix_new(image->Nr, image->Nc);
  imatrix *gmag;

  if (debug) fprintf(stderr, "this #0: %f\n", ETF_crc(this));
  if (debug) fprintf(stderr, "uninit tmp: %f\n", mymatrix_crc(tmp));
  
  this->max_grad = -1.;
  for (i = 1; i < image->Nr - 1; i++) { 
    for (j = 1; j < image->Nc - 1; j++) {
      double a,b,c,d,e;
      /* ////////////////////////////////////////////////////////////// */
 a=     this->p[i][j].tx = (image->p[i+1][j-1] + 2*(double)image->p[i+1][j] + image->p[i+1][j+1] 
       			- image->p[i-1][j-1] - 2*(double)image->p[i-1][j] - image->p[i-1][j+1]) / MAX_VAL;
 b=     this->p[i][j].ty = (image->p[i-1][j+1] + 2*(double)image->p[i][j+1] + image->p[i+1][j+1]
			  - image->p[i-1][j-1] - 2*(double)image->p[i][j-1] - image->p[i+1][j-1]) / MAX_VAL;
      /* /////////////////////////////////////////// */
 c=     v[0] = this->p[i][j].tx;
 d=     v[1] = this->p[i][j].ty;
      /* //////////////////////////////////////////// */
 e=     tmp->p[i][j] = sqrt(this->p[i][j].tx * this->p[i][j].tx + this->p[i][j].ty * this->p[i][j].ty);
 if (debug) if (i == 319 && j == 109) fprintf(stderr, "abcde: %f %f %f %f %f\n", a,b,c,d,e);
      if (tmp->p[i][j] > this->max_grad) {
	this->max_grad = tmp->p[i][j];
        if (debug) if (i == 319 && j == 109) fprintf(stderr, "max grad[%d][%d] = %f\n", i, j, this->max_grad);
      }
    }
  }
  if (debug) fprintf(stderr, "tmp0: %f\n", mymatrix_crc(tmp));
  if (debug) fprintf(stderr, "this #1: %f\n", ETF_crc(this));

  for (i = 1; i <= this->Nr - 2; i++) {
    tmp->p[i][0] = tmp->p[i][1];
    tmp->p[i][this->Nc - 1] = tmp->p[i][this->Nc - 2];
  }
  if (debug) fprintf(stderr, "tmp1: %f\n", mymatrix_crc(tmp));

  for (j = 1; j <= this->Nc - 2; j++) {
    tmp->p[0][j] = tmp->p[1][j];
    tmp->p[this->Nr - 1][j] = tmp->p[this->Nr - 2][j];
  }
  if (debug) fprintf(stderr, "tmp2: %f\n", mymatrix_crc(tmp));
  
  tmp->p[0][0] = ( tmp->p[0][1] + tmp->p[1][0] ) / 2;
  tmp->p[0][this->Nc-1] = ( tmp->p[0][this->Nc-2] + tmp->p[1][this->Nc-1] ) / 2;
  tmp->p[this->Nr-1][0] = ( tmp->p[this->Nr-1][1] + tmp->p[this->Nr-2][0] ) / 2;
  tmp->p[this->Nr - 1][this->Nc - 1] = ( tmp->p[this->Nr - 1][this->Nc - 2] + tmp->p[this->Nr - 2][this->Nc - 1] ) / 2;
  if (debug) fprintf(stderr, "tmp3: %f\n", mymatrix_crc(tmp));
  
  gmag = imatrix_new(this->Nr, this->Nc);
  
  /* normalize the magnitude */
  for (i = 0; i < this->Nr; i++) { 
    for (j = 0; j < this->Nc; j++) {
      tmp->p[i][j] /= this->max_grad;
      gmag->p[i][j] = round(tmp->p[i][j] * 255.0);
    }
  }

  if (debug) fprintf(stderr, "last tmp: %f\n", mymatrix_crc(tmp));
  if (debug) fprintf(stderr, "gmag: %d\n", imatrix_crc(gmag));
  if (debug) fprintf(stderr, "v: %f\n", v[0]+v[1]);
  
  for (i = 1; i < this->Nr - 1; i++) { 
    for (j = 1; j < this->Nc - 1; j++) {
      /* ////////////////////////////////////////////////////////////// */
      this->p[i][j].tx = (gmag->p[i+1][j-1] + 2*(double)gmag->p[i+1][j] + gmag->p[i+1][j+1] 
			  - gmag->p[i-1][j-1] - 2*(double)gmag->p[i-1][j] - gmag->p[i-1][j+1]) / MAX_VAL;
      this->p[i][j].ty = (gmag->p[i-1][j+1] + 2*(double)gmag->p[i][j+1] + gmag->p[i+1][j+1]
			  - gmag->p[i-1][j-1] - 2*(double)gmag->p[i][j-1] - gmag->p[i+1][j-1]) / MAX_VAL;
      /* /////////////////////////////////////////// */
      v[0] = this->p[i][j].tx;
      v[1] = this->p[i][j].ty;
      this->p[i][j].tx = -v[1];
      this->p[i][j].ty = v[0];
      /* //////////////////////////////////////////// */
      this->p[i][j].mag = sqrt(this->p[i][j].tx * this->p[i][j].tx + this->p[i][j].ty * this->p[i][j].ty);
      
      if (this->p[i][j].mag > this->max_grad) {
	this->max_grad = this->p[i][j].mag;
      }
    }
  }
  
  for (i = 1; i <= this->Nr - 2; i++) {
    this->p[i][0].tx = this->p[i][1].tx;
    this->p[i][0].ty = this->p[i][1].ty;
    this->p[i][0].mag = this->p[i][1].mag;
    this->p[i][this->Nc - 1].tx = this->p[i][this->Nc - 2].tx;
    this->p[i][this->Nc - 1].ty = this->p[i][this->Nc - 2].ty;
    this->p[i][this->Nc - 1].mag = this->p[i][this->Nc - 2].mag;
  }
  
  for (j = 1; j <= this->Nc - 2; j++) {
    this->p[0][j].tx = this->p[1][j].tx;
    this->p[0][j].ty = this->p[1][j].ty;
    this->p[0][j].mag = this->p[1][j].mag;
    this->p[this->Nr - 1][j].tx = this->p[this->Nr - 2][j].tx;
    this->p[this->Nr - 1][j].ty = this->p[this->Nr - 2][j].ty;
    this->p[this->Nr - 1][j].mag = this->p[this->Nr - 2][j].mag;
  }
  
  this->p[0][0].tx = ( this->p[0][1].tx + this->p[1][0].tx ) / 2;
  this->p[0][0].ty = ( this->p[0][1].ty + this->p[1][0].ty ) / 2;
  this->p[0][0].mag = ( this->p[0][1].mag + this->p[1][0].mag ) / 2;
  this->p[0][this->Nc-1].tx = ( this->p[0][this->Nc-2].tx + this->p[1][this->Nc-1].tx ) / 2;
  this->p[0][this->Nc-1].ty = ( this->p[0][this->Nc-2].ty + this->p[1][this->Nc-1].ty ) / 2;
  this->p[0][this->Nc-1].mag = ( this->p[0][this->Nc-2].mag + this->p[1][this->Nc-1].mag ) / 2;
  this->p[this->Nr-1][0].tx = ( this->p[this->Nr-1][1].tx + this->p[this->Nr-2][0].tx ) / 2;
  this->p[this->Nr-1][0].ty = ( this->p[this->Nr-1][1].ty + this->p[this->Nr-2][0].ty ) / 2;
  this->p[this->Nr-1][0].mag = ( this->p[this->Nr-1][1].mag + this->p[this->Nr-2][0].mag ) / 2;
  this->p[this->Nr - 1][this->Nc-1].tx = (this->p[this->Nr - 1][this->Nc - 2].tx + this->p[this->Nr - 2][this->Nc - 1].tx)/2;
  this->p[this->Nr - 1][this->Nc-1].ty = (this->p[this->Nr - 1][this->Nc - 2].ty + this->p[this->Nr - 2][this->Nc - 1].ty)/2;
  this->p[this->Nr - 1][this->Nc-1].mag= (this->p[this->Nr - 1][this->Nc - 2].mag+this->p[this->Nr - 2][this->Nc - 1].mag)/2;
	
  if (debug) fprintf(stderr, "this before norm: %f\n", ETF_crc(this));
  normalize(this);
  if (debug) fprintf(stderr, "this after norm - returning %f from set2\n", ETF_crc(this));

}

static void make_unit(double *vx, double *vy) {
  double mag = sqrt( *vx * *vx + *vy * *vy );
  if (mag != 0.0) { *vx /= mag; *vy /= mag; }
}

/*
 * Normalise the whole field in place: every (tx, ty) is scaled to unit
 * length (zero vectors stay zero) and every mag is divided by max_grad.
 *
 * When max_grad is exactly 0.0 the division is skipped and mag is left as
 * is; dividing would turn every magnitude into NaN or infinity, which then
 * spreads through later passes over the field.
 */
void normalize(ETF *this) {
  int i, j;
  const int can_scale = (this->max_grad != 0.0);

  for (i = 0; i < this->Nr; i++) {
    for (j = 0; j < this->Nc; j++) {
      make_unit(&this->p[i][j].tx, &this->p[i][j].ty);
      if (can_scale) {
        this->p[i][j].mag /= this->max_grad;
      }
    }
  }
}


/*
 * One directional filtering pass from src into dst (only tx/ty of dst are
 * written).
 *
 * For each cell the 2*half_w+1 cells along one index direction are visited,
 * with out-of-range positions clamped to the nearest valid index.  Each
 * visited vector is weighted by (its mag - centre mag + 1), flipped when its
 * dot product with the centre vector is negative, and accumulated; the sum
 * is stored with unit length.
 *
 * (step_r, step_c) = (1, 0) walks the first index of p, (0, 1) the second.
 */
static void etf_smooth_pass(ETF *src, ETF *dst, int half_w, int step_r, int step_c) {
  const int rows = ETF_getRow(src);
  const int cols = ETF_getCol(src);
  int r, c, off;

  for (c = 0; c < cols; c++) {
    for (r = 0; r < rows; r++) {
      const double centre_x = src->p[r][c].tx;
      const double centre_y = src->p[r][c].ty;
      double sum_x = 0.0;
      double sum_y = 0.0;

      for (off = -half_w; off <= half_w; off++) {
        int nr = r + off * step_r;
        int nc = c + off * step_c;
        double weight;
        double sign;
        double dot;

        if (nr > rows - 1) nr = rows - 1;
        else if (nr < 0) nr = 0;
        if (nc > cols - 1) nc = cols - 1;
        else if (nc < 0) nc = 0;

        dot = centre_x * src->p[nr][nc].tx + centre_y * src->p[nr][nc].ty;
        sign = (dot < 0.0) ? -1.0 : 1.0;
        weight = (src->p[nr][nc].mag - src->p[r][c].mag) + 1;

        sum_x += weight * src->p[nr][nc].tx * sign;
        sum_y += weight * src->p[nr][nc].ty * sign;
      }

      make_unit(&sum_x, &sum_y);
      dst->p[r][c].tx = sum_x;
      dst->p[r][c].ty = sum_y;
    }
  }
}

/*
 * Smooth the field in place.
 *
 * Runs M iterations; each iteration applies one pass along the first index
 * of p followed by one pass along the second index, with a window of
 * 2*half_w+1 cells.  A scratch field of the same size is allocated with
 * New_ETF() and each pass result is copied back into this.
 *
 * When debug is non-zero, a checksum of the scratch field is written to
 * stderr after every pass.
 *
 * NOTE(review): the scratch field is not released here; confirm whether the
 * allocator behind New_ETF() reclaims it.
 */
void Smooth(ETF *this, int half_w, int M) {
  int iter;
  ETF *scratch = New_ETF(ETF_getRow(this), ETF_getCol(this));

  ETF_copy(scratch, this);

  if (debug) fprintf(stderr, "e2=this -> %f\n", ETF_crc(scratch));

  for (iter = 0; iter < M; iter++) {
    if (debug) fprintf(stderr, "Smooth: k = %d\n", iter);

    /* pass labelled "horizontal" in the debug output: neighbours differ in
       the first index */
    etf_smooth_pass(this, scratch, half_w, 1, 0);
    if (debug) fprintf(stderr, "horizontal loop -> %f\n", ETF_crc(scratch));
    ETF_copy(this, scratch);

    /* pass labelled "vertical": neighbours differ in the second index */
    etf_smooth_pass(this, scratch, half_w, 0, 1);
    if (debug) fprintf(stderr, "vertical loop -> %f\n", ETF_crc(scratch));
    ETF_copy(this, scratch);
  }
}