#include "atlas_threads.h"
#include <atlas_lvl3.h>
/*
 * Argument bundle for one gemm call.  Cptgemm fills in one of these per
 * panel and Cptgemm2 forwards the members, unchanged, to Cgemm.
 */
typedef struct ATL_TGEMMt ATL_TGEMMt;
struct ATL_TGEMMt
{
   enum ATLAS_TRANS TA;   /* transpose setting passed for A */
   enum ATLAS_TRANS TB;   /* transpose setting passed for B */
   int m;                 /* problem dimensions handed to Cgemm */
   int n;
   int k;
   int lda;               /* leading dimensions of A, B and C */
   int ldb;
   int ldc;
   const void *A;         /* input operands (read only) */
   const void *B;
   void *C;               /* output operand */
/*
 * One declarator per declaration: should SCALAR be a macro that expands to
 * a pointer type, "SCALAR alpha, beta;" would give beta the pointed-to type
 * rather than the pointer type.
 */
   SCALAR alpha;
   SCALAR beta;
};

/*
 * File-local shorthands for the routine names used below.  Each expands,
 * through tname (defined elsewhere; presumably token concatenation -- TODO
 * confirm), to a name built from ATL_C, PRE and the given suffix.  All four
 * are #undef'ed at the end of this file.
 */
#define Cmatadd  tname(tname(ATL_C,PRE),matadd)
#define Cptgemm  tname(tname(ATL_C,PRE),ptgemm)
#define Cptgemm2 tname(tname(ATL_C,PRE),ptgemm2)
#define Cgemm    tname(tname(ATL_C,PRE),gemm)

void Cmatadd(int M, int N, const TYPE *A, TYPE *C, int ldc)
/*
 * C += A;  I should generate and thread this, but I'm not going to bother
 * right now, since the code should only rarely partition along K.
 *
 * A holds N columns of M elements stored back to back (leading dimension M);
 * C holds N columns of M elements with leading dimension ldc.  When TCPLX is
 * defined every element occupies two consecutive TYPE entries, so M and ldc
 * are doubled before any addressing is derived from them.
 * Nothing is touched when M < 1 or N < 1.
 */
{
   int i, j;

   if (M < 1 || N < 1) return;  /* empty matrix: nothing to accumulate */
#ifdef TCPLX
   M <<= 1;
   ldc <<= 1;
#endif
/*
 * Walk column by column with counted loops, so no end pointer is formed
 * from the int product M*N and the column length always matches the
 * (possibly doubled) M.
 */
   for (j=0; j != N; j++)
   {
      for (i=0; i != M; i++) C[i] += A[i];
      A += M;
      C += ldc;
   }
}

void *Cptgemm2(void *ptr)
/*
 * Thread start routine: ptr addresses an ATL_TGEMMt, whose members are
 * passed as the arguments of a single Cgemm call.  Always returns NULL.
 */
{
   const ATL_TGEMMt *job = ptr;

   Cgemm(job->TA, job->TB,
         job->m, job->n, job->k,
         job->alpha, job->A, job->lda,
         job->B, job->ldb,
         job->beta, job->C, job->ldc);
   return NULL;
}

void Cptgemm(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB, int M, int N,
             int K, const SCALAR alpha, const TYPE *A, int lda, const TYPE *B,
             int ldb, const SCALAR beta, TYPE *C, int ldc)
/*
 * Threaded front end to Cgemm; takes the same arguments.
 *
 * The problem is cut into panels along one dimension and each panel is
 * handed to Cgemm through Cptgemm2: np-1 panels run in new threads and the
 * last one runs in the calling thread.  The dimension is chosen in the order
 * N, M, K, taking the first that is at least NB*ATL_NTHREADS.  If none is,
 * the larger of M and N is split into M/NB (resp. N/NB) panels provided that
 * gives at least two; otherwise Cgemm is simply called directly.
 *
 * Panel sizes are multiples of NB, except that panel number neb also takes
 * the remainder of the split dimension.
 *
 * When K is split, every panel but the last writes its product (beta = Zero)
 * to a malloc'ed M x N workspace with leading dimension M; the last panel
 * applies the caller's beta to C, and the workspaces are then summed into C
 * with Cmatadd and freed.
 *
 * All bookkeeping lives on the stack, so the routine keeps no state between
 * (or across concurrent) calls.  Failures of malloc, pthread_create and
 * pthread_join are reported through ATL_assert.
 */
{
   int i, k, nblock, nblock1, neb, ierr, np = ATL_NTHREADS;
   enum {PartN, PartM, PartK, NoPart} part=NoPart; /* what dim to partition */
#ifdef TREAL
   const TYPE Zero = 0.0;
   const int esz = 1;  /* TYPE entries per matrix element */
#else
   const TYPE Zero[2] = {0.0, 0.0};
   const int esz = 2;  /* two TYPE entries per matrix element */
#endif
   pthread_t tp[ATL_NTHREADS];
   pthread_attr_t attr;
   ATL_TGEMMt MAT[ATL_NTHREADS];
   void *vp[ATL_NTHREADS];

/*
 * Degenerate output: leave it to the one-processor code.  This also keeps
 * the K partition from allocating and summing empty workspaces.
 */
   if (M < 1 || N < 1)
   {
      Cgemm(TA, TB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      return;
   }

   i = NB * np;
   if (N >= i) part = PartN;
   else if (M >= i) part = PartM;
   else if (K >= i) part = PartK;
   if (part == NoPart) /* no dimension will go over procs */
   {
      if (M > N)
      {
         part = PartM;
         i = M / NB;
      }
      else
      {
         part = PartN;
         i = N / NB;
      }
      if (i < 2) part = NoPart;
      else np = i;   /* here i < ATL_NTHREADS, so the arrays are big enough */
   }
   if (part == PartN)
   {
      nblock = N / NB;
      nblock1 = nblock / np;
      neb = nblock - nblock1*np;
      for (i=0; i != np; i++)
      {
         if (i < neb) k = (nblock1+1)*NB;
         else if (i == neb) k = N - nblock*NB + nblock1*NB;
         else k = nblock1*NB;
         MAT[i].TA = TA;        MAT[i].TB = TB;
         MAT[i].m = M;          MAT[i].n = k;       MAT[i].k = K;
         MAT[i].lda = lda;      MAT[i].ldb = ldb;   MAT[i].ldc = ldc;
         MAT[i].A = A;          MAT[i].B = B;       MAT[i].C = C;
         MAT[i].alpha = alpha;  MAT[i].beta = beta;
/*
 *       Step B and C past the k columns just assigned; offsets are counted
 *       in TYPE entries, hence the esz factor
 */
         if (TB == AtlasNoTrans) B += k * ldb * esz;
         else B += k * esz;
         C += k * ldc * esz;
      }
   }
   else if (part == PartM)
   {
      nblock = M / NB;
      nblock1 = nblock / np;
      neb = nblock - nblock1*np;
      for (i=0; i != np; i++)
      {
         if (i < neb) k = (nblock1+1)*NB;
         else if (i == neb) k = M - nblock*NB + nblock1*NB;
         else k = nblock1*NB;
         MAT[i].TA = TA;        MAT[i].TB = TB;
         MAT[i].m = k;          MAT[i].n = N;       MAT[i].k = K;
         MAT[i].lda = lda;      MAT[i].ldb = ldb;   MAT[i].ldc = ldc;
         MAT[i].A = A;          MAT[i].B = B;       MAT[i].C = C;
         MAT[i].alpha = alpha;  MAT[i].beta = beta;
/*
 *       Step A and C past the k rows just assigned
 */
         if (TA == AtlasNoTrans) A += k * esz;
         else A += k * lda * esz;
         C += k * esz;
      }
   }
   else if (part == PartK)
   {
      nblock = K / NB;
      nblock1 = nblock / np;
      neb = nblock - nblock1*np;
      for (i=0; i != np; i++)
      {
         if (i < neb) k = (nblock1+1)*NB;
         else if (i == neb) k = K - nblock*NB + nblock1*NB;
         else k = nblock1*NB;
         MAT[i].TA = TA;        MAT[i].TB = TB;
         MAT[i].m = M;          MAT[i].n = N;       MAT[i].k = k;
         MAT[i].lda = lda;      MAT[i].ldb = ldb;
         MAT[i].A = A;          MAT[i].B = B;
         MAT[i].alpha = alpha;
         if (i < np-1)
         {
/*
 *          Partial product goes to a private, dense M x N workspace
 */
            vp[i] = malloc(ATL_MulBySize(M)*N + ATL_Cachelen);
            ATL_assert(vp[i] != NULL);
            MAT[i].C = ATL_AlignPtr(vp[i]);
            MAT[i].beta = Zero;
            MAT[i].ldc = M;
         }
         else
         {
/*
 *          Last panel updates the caller's C in place
 */
            MAT[i].C = C;
            MAT[i].beta = beta;
            MAT[i].ldc = ldc;
         }
/*
 *       Step A and B past the k-long slice of the inner dimension
 */
         if (TA == AtlasNoTrans) A += k * lda * esz;
         else A += k * esz;
         if (TB == AtlasNoTrans) B += k * esz;
         else B += k * ldb * esz;
      }
   }
   else  /* No dimension big enough, just call one-processor code */
   {
      Cgemm(TA, TB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      return;
   }

/*
 * The attribute object is only needed while threads are being created, so
 * it is initialised here and destroyed as soon as the last one is started
 */
   ierr = pthread_attr_init(&attr);
   ATL_assert(ierr == 0);
   #ifdef IBM_PT_ERROR
      pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_UNDETACHED);
   #endif
   #ifdef UseSystemScope
      pthread_attr_setscope(&attr, PTHREAD_SCOPE_SYSTEM);
   #endif
   np--;
   for (i=0; i != np; i++)
   {
      ierr = pthread_create(&tp[i], &attr, Cptgemm2, &MAT[i]);
      ATL_assert(ierr == 0);
   }
   pthread_attr_destroy(&attr);
   Cptgemm2(&MAT[np]);  /* calling thread handles the last panel itself */
/*
 * Wait for the workers.  For the K partition, should do a hypercube combine,
 * but blow that off for now; no big diff for small # of threads.  Probably
 * need to change for 8 or 16 procs.
 */
   for (i=0; i != np; i++)
   {
      ierr = pthread_join(tp[i], NULL);
      ATL_assert(ierr == 0);
      if (part == PartK)
      {
         Cmatadd(M, N, MAT[i].C, C, ldc);
         free(vp[i]);
      }
   }
}

/*
 * Drop the file-local name shorthands defined at the top of this file.
 */
#undef Cgemm
#undef Cmatadd
#undef Cptgemm
#undef Cptgemm2

