// LSDecAmps class definition file. -*- C++ -*-
// Copyright 2012 Bertram Kopf

#include <getopt.h>
#include <fstream>
#include <string>
#include <mutex>

#include "PwaUtils/LSDecAmps.hh"
#include "qft++/relativistic-quantum-mechanics/Utils.hh"
#include "ErrLogger/ErrLogger.hh"
#include "PwaUtils/DataUtils.hh"
#include "PwaUtils/IsobarLSDecay.hh"
//#include "PwaUtils/XdecAmpRegistry.hh"
#include "Particle/Particle.hh"


// Builds the LS-coupled decay amplitude for an isobar decay.
// The list of allowed (J,P,C,L,S) combinations is taken from theDec; the
// default magnitude is shared evenly as 1/sqrt(N) over the N LS amplitudes
// (or left at 1 when the list is empty).
// _parityFactor = P(daughter1) * P(daughter2) * (-1)^(J - j1 - j2), where J is
// taken from _JPCPtr and j1, j2 from the daughters' J().
// Finally the Clebsch-Gordan prefactors for every LS amplitude are tabulated.
LSDecAmps::LSDecAmps(boost::shared_ptr<IsobarLSDecay> theDec) :
  AbsXdecAmp(theDec)
  ,_JPCLSs(theDec->JPCLSAmps())
  ,_factorMag(1.)
{
  if(!_JPCLSs.empty()) _factorMag=1./sqrt(_JPCLSs.size());

  Particle* const dau1=_decay->daughter1Part();
  Particle* const dau2=_decay->daughter2Part();
  _parityFactor=dau1->theParity()*dau2->theParity()*pow(-1,_JPCPtr->J-dau1->J()-dau2->J());
  Info << "_parityFactor=\t" << _parityFactor << endmsg;

  fillCgPreFactor();
}

// Builds the decay amplitude from a generic AbsDecay.
// Unlike the IsobarLSDecay constructor, no LS amplitudes are taken from the
// decay here, so _JPCLSs starts empty.
// _factorMag is still initialized to the same default (1.) as in the other
// constructor, so it is never read indeterminate: getDefaultParams uses it for
// every entry in _JPCLSs, which a derived class may fill later.
// _parityFactor = P(daughter1) * P(daughter2) * (-1)^(J - j1 - j2).
LSDecAmps::LSDecAmps(boost::shared_ptr<AbsDecay> theDec) :
  AbsXdecAmp(theDec)
  ,_factorMag(1.)
{
  Particle* daughter1=_decay->daughter1Part();
  Particle* daughter2=_decay->daughter2Part();
  _parityFactor=daughter1->theParity()*daughter2->theParity()*pow(-1,_JPCPtr->J-daughter1->J()-daughter2->J());
  Info << "_parityFactor=\t" << _parityFactor << endmsg;
  fillCgPreFactor();
}

// Destructor: no explicit cleanup is performed by this class.
LSDecAmps::~LSDecAmps(){
  // intentionally empty
}


// Partial amplitude with the helicity of one daughter fixed to lamDec.
//   fixDaughterNr == 1 -> daughter 1 helicity pinned to lamDec
//   fixDaughterNr == 2 -> daughter 2 helicity pinned to lamDec
//   anything else      -> an Alert is logged and both daughters keep their
//                         full helicity range
// If a final-state helicity is enabled for a daughter, that daughter is pinned
// to lamFs instead (this overrides the lamDec choice).
// Daughter decay amplitudes and dynamics are NOT included (lsLoop withDecs=false).
// grandmaAmp is not used here.
complex<double> LSDecAmps::XdecPartAmp(Spin lamX, Spin lamDec, short fixDaughterNr, EvtData* theData, Spin lamFs, AbsXdecAmp* grandmaAmp){

  // Full helicity ranges [-j, j] of both daughters.
  Spin lo1=-_Jdaughter1;
  Spin hi1= _Jdaughter1;
  Spin lo2=-_Jdaughter2;
  Spin hi2= _Jdaughter2;

  switch(fixDaughterNr){
  case 1:
    lo1=lamDec;
    hi1=lamDec;
    break;
  case 2:
    lo2=lamDec;
    hi2=lamDec;
    break;
  default:
    Alert << "Invalid fixDaughterNr in XdecPartAmp." << endmsg;
    break;
  }

  if(_enabledlamFsDaughter1){
    lo1=lamFs;
    hi1=lamFs;
  }
  else if(_enabledlamFsDaughter2){
    lo2=lamFs;
    hi2=lamFs;
  }

  return lsLoop(lamX, theData, lo1, hi1, lo2, hi2, false);
}




// Full decay amplitude for helicity lamX.
// Computed as lsLoop(..., withDecs=true), which includes the daughter
// amplitudes, multiplied by the dynamics factor _absDyn->eval(theData, grandmaAmp).
// Returns 0 when |lamX| exceeds _JPCPtr->J.
// When _cacheAmps is set, results are cached per (grandma key, event number,
// lamX, lamFs). The cached value is returned directly while _recalculate is false.
// If a final-state helicity is enabled for one daughter, that daughter's
// helicity range is pinned to lamFs.
complex<double> LSDecAmps::XdecAmp(Spin lamX, EvtData* theData, Spin lamFs, AbsXdecAmp* grandmaAmp){

  complex<double> result(0.,0.);
  if( fabs(lamX) > _JPCPtr->J) return result;

  int evtNo=theData->evtNo;
  std::string currentKey=_absDyn->grandMaKey(grandmaAmp);

  if ( _cacheAmps && !_recalculate){
    // NOTE(review): this read is not guarded by theMutex. std::map::operator[]
    // default-inserts (value 0) if the entry is missing, which would race with
    // the locked write below. Confirm the cache is fully populated before
    // concurrent reads happen.
    result=_cachedGrandmaAmpMap[currentKey][evtNo][lamX][lamFs];
    return result;
  }

  Spin lam1Min=-_Jdaughter1;
  Spin lam1Max= _Jdaughter1;
  Spin lam2Min=-_Jdaughter2;
  Spin lam2Max=_Jdaughter2;

  if(_enabledlamFsDaughter1){
    lam1Min=lamFs;
    lam1Max=lamFs;
  }
  else if(_enabledlamFsDaughter2){
    lam2Min=lamFs;
    lam2Max=lamFs;
  }

  result=lsLoop(lamX, theData, lam1Min, lam1Max, lam2Min, lam2Max, true, lamFs );

  result*=_absDyn->eval(theData, grandmaAmp);

  if ( _cacheAmps){
    // RAII guard: the mutex is released even if inserting into the nested
    // cache maps throws (e.g. std::bad_alloc). A manual lock()/unlock() pair
    // would leave it locked in that case.
    std::lock_guard<decltype(theMutex)> cacheLock(theMutex);
    _cachedGrandmaAmpMap[currentKey][evtNo][lamX][lamFs]=result;
  }

  return result;
}


// Coherent sum over all LS amplitudes and over the daughter helicities
// lambda1 in [lam1Min, lam1Max] and lambda2 in [lam2Min, lam2Max].
// Each term is
//   mag * exp(i*phi) * _cgPreFactor[LS][lambda1][lambda2]
//       * conj(WignerDsString[_wignerDKey][J][lamX][lambda1 - lambda2]),
// multiplied by daughterAmp(lambda1, lambda2, ...) when withDecs is true.
// Combinations with |lambda1 - lambda2| > J or > S are skipped.
// NOTE(review): the Wigner table entry is presumably D^J_{lamX,lambda}; confirm
// against the code that fills WignerDsString.
complex<double> LSDecAmps::lsLoop(Spin lamX, EvtData* theData, Spin lam1Min, Spin lam1Max, Spin lam2Min, Spin lam2Max, bool withDecs, Spin lamFs ){
  complex<double> total(0.,0.);

  for (const boost::shared_ptr<const JPCLS>& ls : _JPCLSs){
    const double mag=_currentParamMags[ls];
    const double phi=_currentParamPhis[ls];
    const complex<double> phase(cos(phi), sin(phi));

    for(Spin h1=lam1Min; h1<=lam1Max; ++h1){
      for(Spin h2=lam2Min; h2<=lam2Max; ++h2){
        Spin h=h1-h2;
        const bool forbidden = fabs(h)>ls->J || fabs(h)>ls->S;
        if(forbidden) continue;

        complex<double> term = mag*phase*_cgPreFactor[ls][h1][h2]*conj( theData->WignerDsString[_wignerDKey][ls->J][lamX][h]);
        if(withDecs) term *= daughterAmp(h1, h2, theData, lamFs, this);
        total+=term;
      }
    }
  }

  return total;
}


// Writes this decay's default fit parameters into fitVal / fitErr under _key.
// For every LS amplitude:
//   magnitude = _factorMag, magnitude error = _factorMag,
//   phase = 0, phase error = 0.3.
// Each per-key map is replaced as a whole. The request is then forwarded to
// the dynamics and to every unstable daughter amplitude.
void  LSDecAmps::getDefaultParams(fitParams& fitVal, fitParams& fitErr){
  typedef std::map< boost::shared_ptr<const JPCLS>, double, pawian::Collection::SharedPtrLess > LSParamMap;

  LSParamMap magVals;
  LSParamMap phiVals;
  LSParamMap magErrs;
  LSParamMap phiErrs;

  for(const boost::shared_ptr<const JPCLS>& ls : _JPCLSs){
    magVals[ls]=_factorMag;
    phiVals[ls]=0.;
    magErrs[ls]=_factorMag;
    phiErrs[ls]=0.3;
  }

  fitVal.Mags[_key]=magVals;
  fitVal.Phis[_key]=phiVals;
  fitErr.Mags[_key]=magErrs;
  fitErr.Phis[_key]=phiErrs;

  _absDyn->getDefaultParams(fitVal, fitErr);

  if(!_daughter1IsStable) _decAmpDaughter1->getDefaultParams(fitVal, fitErr);
  if(!_daughter2IsStable) _decAmpDaughter2->getDefaultParams(fitVal, fitErr);
}

// Placeholder printer: writes nothing to os.
void LSDecAmps::print(std::ostream& os) const{
  // intentionally a no-op
}


// Decides whether this amplitude must be recomputed for theParamVal.
// The result is stored in _recalculate and also returned.
// The dynamics and every unstable daughter amplitude are always asked first,
// with no short-circuit between them, so each can update its own state.
// Only if none of them requests recalculation are this decay's LS magnitudes
// and phases compared with the values stored by updateFitParams, using an
// absolute tolerance of 1e-10.
// Side effect: operator[] default-inserts (0.) any missing entries in
// theParamVal and in the current-parameter maps.
bool LSDecAmps::checkRecalculation(fitParams& theParamVal){
  _recalculate=false;

  if(_absDyn->checkRecalculation(theParamVal)) _recalculate=true;
  if(!_daughter1IsStable && _decAmpDaughter1->checkRecalculation(theParamVal)) _recalculate=true;
  if(!_daughter2IsStable && _decAmpDaughter2->checkRecalculation(theParamVal)) _recalculate=true;

  if(_recalculate) return _recalculate;

  typedef std::map< boost::shared_ptr<const JPCLS>, double, pawian::Collection::SharedPtrLess > LSParamMap;
  LSParamMap& newMags=theParamVal.Mags[_key];
  LSParamMap& newPhis=theParamVal.Phis[_key];

  const double tolerance=1.e-10;
  for(const boost::shared_ptr<const JPCLS>& ls : _JPCLSs){
    const double mag=newMags[ls];
    const double phi=newPhis[ls];

    if( fabs(mag - _currentParamMags[ls]) > tolerance ){
      _recalculate=true;
      return _recalculate;
    }
    if( fabs(phi - _currentParamPhis[ls]) > tolerance ){
      _recalculate=true;
      return _recalculate;
    }
  }

  return _recalculate;
}
 

// Copies this decay's LS magnitudes and phases from theParamVal (entries under
// _key) into _currentParamMags / _currentParamPhis. The update is then
// forwarded to the dynamics and to every unstable daughter amplitude.
// Side effect: operator[] default-inserts (0.) missing entries in theParamVal.
void  LSDecAmps::updateFitParams(fitParams& theParamVal){
  typedef std::map< boost::shared_ptr<const JPCLS>, double, pawian::Collection::SharedPtrLess > LSParamMap;
  LSParamMap& newMags=theParamVal.Mags[_key];
  LSParamMap& newPhis=theParamVal.Phis[_key];

  for(const boost::shared_ptr<const JPCLS>& ls : _JPCLSs){
    const double mag=newMags[ls];
    const double phi=newPhis[ls];
    _currentParamMags[ls]=mag;
    _currentParamPhis[ls]=phi;
  }

  _absDyn->updateFitParams(theParamVal);

  if(!_daughter1IsStable) _decAmpDaughter1->updateFitParams(theParamVal);
  if(!_daughter2IsStable) _decAmpDaughter2->updateFitParams(theParamVal);
}

void  LSDecAmps::fillCgPreFactor(){

  std::vector< boost::shared_ptr<const JPCLS> >::iterator it;
  for (it=_JPCLSs.begin(); it!=_JPCLSs.end(); ++it){
    for(Spin lambda1=-_Jdaughter1; lambda1<=_Jdaughter1; ++lambda1){
      for(Spin lambda2=-_Jdaughter2; lambda2<=_Jdaughter2; ++lambda2){
	Spin lambda = lambda1-lambda2;
	if( fabs(lambda)>(*it)->J || fabs(lambda)>(*it)->S) continue;

	_cgPreFactor[*it][lambda1][lambda2]=sqrt(2.*(*it)->L+1)
	  *Clebsch((*it)->L, 0, (*it)->S, lambda, (*it)->J, lambda)
	  *Clebsch(_Jdaughter1, lambda1, _Jdaughter2, -lambda2, (*it)->S, lambda  );
      }
    }
  }
}
