#!/usr/bin/env python
#-*- coding: UTF-8 -*-
##############################################
## Copyright @ Lianjin Wu (wulj@ihep.ac.cn)
## 436f7079526967687440204c69616e6a696e205775203c77756c6a40696865702e61632e636e3e
##############################################

import array
import binascii, time, os, sys
import collections
#import subprocess

import ROOT as root

import particle, dataparser

class Topology():
	"""
		Classify events by decay topology, read from track/mother branches of a ROOT tree.

		Events are read from a chain configured either by a topo.dt card or by the
		constructor arguments. Tracks are linked into mother -> children decays, and every
		event is tagged with a decay-chain ID (dcID), a final-state ID (fsID), a
		wanted-decay ID (wdID) and a wanted-particle ID (wpID). A LaTeX/PDF summary is
		written to <respath>/wudecay.tex, and the IDs are stored as new branches in
		<respath>/wudecay.root.

		author: Lianjin Wu (wulj@ihep.ac.cn)
	"""

	# Items of a markDecay tuple that flag final-state-radiation photons (track ID -22).
	_FSRMARKERS = ("fsr", "yes", -22, "-22")
	# Absolute track IDs counted as final-state particles when building fsMark.
	_FINALSTATEIDS = (11, 13, 22, 211, 321, 2212)

	def __init__(self, topodata = "", chain = "", ntrack = "", motherid = "", motherindex = "", trackid = "", trackindex = "", fromgenerate = "", primaryparticle = "", removespecials = None):
		"""
			author: Lianjin Wu (wulj@ihep.ac.cn)

			Initialize from a topo.dt card (topodata), or directly from a chain plus
			the names of the branches holding the per-track information.

			@ topodata:        path of the topo.dt card; when non-empty the other arguments are ignored
			@ chain:           tree/chain to analyse
			@ ntrack:          branch name of the number of tracks in an event
			@ motherid, motherindex, trackid, trackindex: branch names of the per-track arrays
			@ fromgenerate:    optional branch name; tracks whose value is 0 are skipped
			@ primaryparticle: optional branch name (stored; currently not applied)
			@ removespecials:  list of python expressions; a track is dropped when any is true.
			                   They may use the names mid, mix, tid, tix, evt and i.
			example: removespecials = ["mix == tix and abs(tid) == abs(mid) and abs(tid) in [11, 22]", "abs(tid) in [91, 4, 23]"]
		"""
		self._topodt = topodata
		self._markdecayids, self._decayids = list(), list()
		self._markparticleids, self._particleids = list(), list()
		self._respath = binascii.unhexlify("7775746f706f").decode()
		self._right = binascii.unhexlify("436f7079526967687440204c69616e6a696e20577520243c245c687265667b6d61696c746f3a2077756c6a40696865702e61632e636e7d7b77756c6a40696865702e61632e636e7d243e24").decode()
		if self._topodt != "":
			self.__parseTopoData()
		elif chain != "" and ntrack != "" and motherid != "" and motherindex != "" and trackid != "" and trackindex != "":
			self._chain = chain
			self._ntrack, self._motherid, self._motherindex, self._trackid, self._trackindex, self._fromgenerate, self._primaryparticle = ntrack, motherid, motherindex, trackid, trackindex, fromgenerate, primaryparticle
			# a fresh list per instance (avoids a shared mutable default argument)
			self._removespecials = removespecials if removespecials is not None else list()
		else:
			print("You have to initialize topology by topo.dt or yourself")

	def markDecay(self, markid, decays):
		"""
			Register a wanted set of decays under markid; matching events get it in wdID.

			@ markid: setted ID used to mark the decay you wanted
			@ decays: the included decays you wanted
			    example [(motherid, childrenid), ... ], such as: [(433, 22, 221, 221), (221, 211, -211, 111, "fsr"/"yes"/-22), (111, 22, 22, "fsr"/"yes"/-22)]
			    if motherid = 0, mother can be any particle
			    "fsr"/"yes"/-22 (or the string "-22") marks that final-state gamma (-22) is included;
			    such a decay only matches decays that contain a -22 child.
		"""
		self._markdecayids.append(markid)
		newdecays = list()
		for idecay in sorted(decays, key = lambda x: x[0]):
			hasfsr = any(item in self._FSRMARKERS for item in idecay)
			stripped = [item for item in idecay if item not in self._FSRMARKERS]
			# mother first, children sorted: the same normalised form __markEvent builds
			normalized = [stripped[0]] + sorted(stripped[1:])
			if hasfsr:
				# trailing 0 is the FSR tag that __markEvent appends when a -22 child exists
				normalized.append(0)
			newdecays.append(tuple(normalized))
		# (decay, multiplicity) pairs: a decay listed k times must occur at least k times
		regulardecays = [(idecay, newdecays.count(idecay)) for idecay in set(newdecays)]
		self._decayids.append(tuple(regulardecays))

	def markParticle(self, markid, particles):
		"""
			Register a wanted set of particles under markid; matching events get it in wpID.

			@ markid:     setted ID used to mark the particle you wanted
			@ particles:  included particles you wanted
			     example [particleid], such as [433, 211]; an ID listed k times must appear at least k times
		"""
		self._markparticleids.append(markid)
		self._particleids.append(tuple((iparticle, particles.count(iparticle)) for iparticle in set(particles)))

	@staticmethod
	def _now():
		"""Return a high-resolution timer value in seconds (time.clock was removed in Python 3.8)."""
		clock = getattr(time, "perf_counter", None)
		return clock() if clock is not None else time.clock()

	@staticmethod
	def _timeStamp():
		"""Return the current local time formatted for the progress messages."""
		return time.strftime("%a %b %d %H:%M:%S %d/%m/%Y", time.localtime())

	def _runStep(self, label, func, *args):
		"""Call func(*args), print the '<label> event end' progress line with its cost, and return the result."""
		starttime = self._now()
		result = func(*args)
		stoptime = self._now()
		print("        >>>>>> %s event end at %s"%(label, self._timeStamp()) + " [ cost %.2f sec ]"%(stoptime - starttime) + " <<<<<< ")
		return result

	def _checkCopyright(self):
		"""Exit with status -1 when the copyright string has been altered."""
		if binascii.unhexlify("4c69616e6a696e205775").decode() not in self._right:
			print("ALARM: Do not change copyright")
			sys.exit(-1)

	@staticmethod
	def _uniqueInOrder(items):
		"""Return the distinct elements of items, keeping the order of first appearance."""
		seen, unique = set(), list()
		for item in items:
			if item in seen: continue
			seen.add(item)
			unique.append(item)
		return unique

	@staticmethod
	def _groupRepresentatives(events, countkey, idkey):
		"""
			Return the first event of each distinct getProperty(idkey) value.

			The result is ordered by getProperty(countkey) descending, then by ID ascending.
			Grouping by ID, instead of stepping through a count-sorted list, keeps
			different IDs that share a count from being skipped or listed twice.
		"""
		firstbyid = dict()
		for event in events:
			firstbyid.setdefault(event.getProperty(idkey), event)
		return sorted(firstbyid.values(), key = lambda x: (-x.getProperty(countkey), x.getProperty(idkey)))

	def makeTopo(self):
		"""
			Run the full pipeline and print timing information.

			The steps are: fill events from the chain, mark them, write and compile
			the PDF summary, then write the IDs into <respath>/wudecay.root.
		"""
		print()
		sumstarttime = self._now()
		print("====== Topology start at %s"%(self._timeStamp()) + " ======")

		filledeventlist = self._runStep("Fill", self.__fillEventList)
		eventlist = self._runStep("Mark", self.__markEvent, filledeventlist)
		self._runStep("Show", self.__showTopo, eventlist)
		self._runStep("Restore", self.__restoreTopo, eventlist)

		sumstoptime = self._now()
		print("====== Topology end at %s"%(self._timeStamp()) + " [ cost %.2f sec ]"%(sumstoptime - sumstarttime) + " ======")
		print()

	def __parseTopoData(self):
		"""
			Parse the topo.dt card (self._topodt) and configure the instance from it.

			Reads the tree name, input files and cut, and stores the cut-applied tree in
			self._chain. Reads the branch names and the remove_specials expressions
			("&&"/"||" become "and"/"or"). Registers every wdID/wpID block through
			markDecay/markParticle. A value of "--" in the card means "not set".
		"""
		data = dataparser.DataParser(self._topodt, [-1, -1, -1, -1, -1, 0, 0])
		treename = data.getDataByKey("tree_name")[0]
		rootnames = data.getDataByKey("root_path_name")
		cut = data.getDataByKey("cut")[0]
		cut = cut if cut != "--" else ""
		chain = root.TChain(treename)
		for rootname in rootnames: chain.Add(rootname)
		self._chain = chain.CopyTree(cut)

		branches = [x[0] for x in data.getDataByKeys(["ntrack", "trackIndex", "trackID", "motherIndex", "motherID", "from_generator", "primary_particle"])]
		self._ntrack, self._trackindex, self._trackid, self._motherindex, self._motherid, self._fromgenerate, self._primaryparticle = branches
		self._fromgenerate = self._fromgenerate if self._fromgenerate != "--" else ""
		self._primaryparticle = self._primaryparticle if self._primaryparticle != "--" else ""

		removes = data.getDataByKey("remove_specials")
		if removes and removes[0] != "--":
			self._removespecials = [x.replace("&&", " and ").replace("||", " or ") for x in removes]
		else:
			self._removespecials = list()

		# First-appearance order, not set order: the registration order decides how several
		# matching IDs are combined (id1*1000 + id2), so it must be stable between runs.
		for wdid in self._uniqueInOrder(data.listRowKeys("wdID")):
			if wdid == "--": continue
			mothers = data.getDataByKey(str(wdid) + "-!-" + "mother")
			children = data.getDataByKey(str(wdid) + "-!-" + "children")
			isfsrs = data.getDataByKey(str(wdid) + "-!-" + "isfsr")
			wdecays = list()
			for im in range(len(mothers)):
				wdecayinfo = [int(mothers[im])] + [int(x) for x in children[im].strip().split("+")]
				if isfsrs[im] not in ["--", "no", "nofsr"]: wdecayinfo.append("fsr")
				wdecays.append(tuple(wdecayinfo))
			self.markDecay(int(wdid), wdecays)

		for wpid in self._uniqueInOrder(data.listRowKeys("wpID")):
			if wpid == "--": continue
			wparticles = data.getDataByKey(str(wpid) + "-!-" + "particle")
			self.markParticle(int(wpid), [int(x) for x in wparticles])

	def __restoreTopo(self, markedeventlist):
		"""
			Write the IDs into <respath>/wudecay.root.

			The input tree is cloned without entries and int branches dcID, fsID, wdID
			and wpID are added. Entry i is filled with the IDs of the marked event whose
			index is i; the process exits with status -1 if the indices disagree.
		"""
		eventlist = sorted(markedeventlist, key = lambda x: x.getIndex())

		dcID, wdID, wpID, fsID = array.array("i", [-999]), array.array("i", [-999]), array.array("i", [-999]), array.array("i", [-999])

		rootfile = root.TFile("%s/wudecay.root"%(self._respath), "recreate")
		try:
			tree = self._chain.CloneTree(0)
			tree.Branch("dcID", dcID, "dcID/I")
			tree.Branch("fsID", fsID, "fsID/I")
			tree.Branch("wdID", wdID, "wdID/I")
			tree.Branch("wpID", wpID, "wpID/I")

			# Iterating the chain loads entry i; Fill() then copies it into the clone
			# (relies on ROOT's CloneTree sharing branch addresses with the source).
			for i, event in enumerate(self._chain):
				if i != eventlist[i].getIndex():
					print("ERROR: the index in eventlist is disoders ")
					sys.exit(-1)
				dcID[0] = eventlist[i].getProperty("decayID")
				fsID[0] = eventlist[i].getProperty("fsID")
				wdID[0] = eventlist[i].getProperty("wdecayID")
				wpID[0] = eventlist[i].getProperty("wparticleID")
				tree.Fill()
			tree.Write()
		finally:
			rootfile.Close()

	def _texHeader(self):
		"""Return the LaTeX preamble, title page and ID legend written at the top of wudecay.tex."""
		lines = [
			"%%%%  Developed by %s %%%%\n" %(self._right),
			"\\documentclass[4pt]{article}\n",
			"\\usepackage{multirow, lscape, geometry, supertabular, fancyhdr, lastpage, booktabs}\n",
			"\\usepackage{longtable, lscape, geometry, fancyhdr, booktabs, times, amsmath, mathptmx}\n",
			"\\usepackage[pdfstartview = FitH, colorlinks, urlcolor = blue, citecolor = blue, linkcolor = blue]{hyperref}\n",
			"\n",
			"\\geometry{left = 1.0cm, right = 1.0cm, top = 2.5cm, bottom = 2.5cm}\n",
			"\n",
			"\\pagestyle{fancy}\n",
			"\\fancyhead{}\n",
			"\\fancyhead[LE, RO]{\\color{blue} Topology}\n",
			"\\fancyhead[RE, LO]{\\color{blue} %s}\n"%(self._right),
			"\\fancyfoot{}\n",
			"\\fancyfoot[LE, RO]{\\color{blue} \\thepage~/~\\pageref{LastPage}}\n",
			"\\fancyfoot[RE, LO]{\\color{blue} \\today}\n",
			"\n",
			"\\renewcommand{\\headrule}{}\n",
			"\n",
			"\\begin{document}\n",
			"\n",
			"\\title{\\huge\\bf TOPOLOGY}\n",
			"\\author{\\bf\\Large %s}\n" %(self._right),
			"\\date{\\bf\\Large\\today}\n",
			"\\maketitle\n",
			"\n",
			"\\vspace{3cm}\n",
			"\\centerline{\\Large \\bf STATEMENT} \n",
			"\\begin{table}[htbp] \n\\centering \n",
			"\\begin{tabular}{ll} \n",
			"\\hline\n",
			"dcID & decay chain ID \\\\ \n",
			"fsID & final state ID \\\\ \n",
			"wdID & wanted decay chain ID \\\\ \n",
			"wpID & wanted particle ID \\\\ \n",
			"\\hline\n",
			"\\end{tabular} \n",
			"\\end{table} \n",
			"\n",
			"\\clearpage\n",
		]
		return "".join(lines)

	def __showTopo(self, markedeventlist):
		"""
			Show event by PDF.

			Writes <respath>/wudecay.tex with one row per distinct decay chain and one
			row per distinct final state (most frequent first), then runs pdflatex on it
			three times in <respath>.
		"""
		self._checkCopyright()

		print("               ... ... make directory ... ... ")
		os.system("mkdir -p %s"%(self._respath))
		print("               ... ... remove old tex ... ... ")
		os.system("rm -f %s/wudecay.*"%(self._respath))

		eventlistbydecayid = self._groupRepresentatives(markedeventlist, "ndecayID", "decayID")
		eventlistbyfsid = self._groupRepresentatives(markedeventlist, "nfsID", "fsID")

		print("               ... ... create new tex and fill ... ... ")
		with open("%s/wudecay.tex"%(self._respath), "w") as topofile:
			topofile.write(self._texHeader())

			topofile.write("\\centerline{\\large\\bf Decay Chain Summary (%i events)} \n"%(len(markedeventlist)))
			topofile.write("\\begin{longtable}{p{0.03\\textwidth}|p{0.70\\textwidth}|p{0.05\\textwidth}|p{0.03\\textwidth}|p{0.03\\textwidth}|p{0.03\\textwidth}}\n")
			topofile.write("\\hline\n")
			topofile.write("No. &Decay chain &Num &dcID &wdID &wpID \\\\ \n")
			topofile.write("\\hline\n")
			for i, event in enumerate(eventlistbydecayid):
				# "$mother \to child child$, " for every decaying track of the event
				decays = "".join("$%s \\to%s$, "%(track.getName(), "".join(" %s"%(child.getName()) for child in track.getChildren())) for track in event.getChildren())
				counts = " &%i &%i &%i &%i"%(event.getProperty("ndecayID"), event.getProperty("decayID"), event.getProperty("wdecayID"), event.getProperty("wparticleID"))
				topofile.write("%i &"%(i) + decays + counts + " \\\\ \n")
				topofile.write("\\hline\n")
			topofile.write("\\end{longtable}\n")

			topofile.write("\\clearpage \n")

			topofile.write("\\centerline{\\large\\bf Final State Summary (%i events)} \n"%(len(markedeventlist)))
			topofile.write("\\begin{longtable}{p{0.03\\textwidth}|p{0.70\\textwidth}|p{0.07\\textwidth}|p{0.05\\textwidth}}\n")
			topofile.write("\\hline\n")
			topofile.write("No. &Final states &Num &fsID \\\\ \n")
			topofile.write("\\hline\n")
			for i, event in enumerate(eventlistbyfsid):
				# fsMark is "id_id_..._"; empty pieces (the trailing one) are skipped
				fsnames = "".join("$%s$" %(particle.Particle(id = idfs, index = -1).getName()) for idfs in event.getProperty("fsMark").strip().split("_") if idfs)
				counts = " &%i &%i"%(event.getProperty("nfsID"), event.getProperty("fsID"))
				topofile.write("%i &"%(i) + fsnames + counts + "\\\\ \n")
				topofile.write("\\hline\n")
			topofile.write("\\end{longtable}\n")

			topofile.write("\\end{document}\n")

		print("               ... ... finish new tex fill ... ... ")
		print("               ... ... compile tex ... ... ")
		compiletex = "cd %s \n pdflatex wudecay.tex > TmpDecayLog \n pdflatex wudecay.tex > TmpDecayLog \n pdflatex wudecay.tex > TmpDecayLog \n rm -f TmpDecayLog wudecay.aux  wudecay.log texput.log wudecay.out"%(self._respath)
		os.system(compiletex)
		print("               ... ... compile tex successfully ... ... ")

	def __fillEventList(self):
		"""
			Read every chain entry into a list of event Particles.

			Each event is particle.Particle(id = 0, index = entry number). Its children
			are the tracks, built as particle.Particle(id = trackID, index = trackIndex)
			with properties trackID, trackIndex, motherID and motherIndex. A track is
			skipped when the from_generator branch (if configured) is 0, or when any
			remove_specials expression is true. Mother/child links between tracks are
			set later, in __markEvent.

			NOTE(review): the branch names and remove_specials come from the card or the
			caller and are eval'd; only use trusted input.
		"""
		self._checkCopyright()
		# Compile once; eval(code) without explicit namespaces runs in this frame, so
		# the expressions see the locals evt, i, mid, mix, tid, tix (same as before).
		ntrackexpr = compile("evt.%s"%self._ntrack, "<ntrack>", "eval")
		midexpr = compile("evt.%s[i]"%self._motherid, "<motherID>", "eval")
		mixexpr = compile("evt.%s[i]"%self._motherindex, "<motherIndex>", "eval")
		tidexpr = compile("evt.%s[i]"%self._trackid, "<trackID>", "eval")
		tixexpr = compile("evt.%s[i]"%self._trackindex, "<trackIndex>", "eval")
		fromgenexpr = compile("evt.%s[i]"%self._fromgenerate, "<from_generator>", "eval") if self._fromgenerate != "" else None
		# strip(): eval() on a string ignores leading blanks, compile() does not
		removeexprs = [compile(removes.strip(), "<remove_specials>", "eval") for removes in (self._removespecials or [])]
		# NOTE: primary_particle is parsed but intentionally not applied here.

		eventlist = list()
		for ievt, evt in enumerate(self._chain):
			ntrk = eval(ntrackexpr)
			ptcevt = particle.Particle(id = 0, index = ievt)
			for i in range(ntrk):
				mid, mix, tid, tix = eval(midexpr), eval(mixexpr), eval(tidexpr), eval(tixexpr)

				if fromgenexpr is not None and eval(fromgenexpr) == 0: continue

				# plain loop (not any(<genexpr>)): eval must run in this frame to see mid/tid/...
				isremovespecials = False
				for removeexpr in removeexprs:
					if eval(removeexpr):
						isremovespecials = True
						break
				if isremovespecials: continue

				ptctrk = particle.Particle(id = tid, index = tix)
				ptctrk.setProperties(trackID = tid, trackIndex = tix, motherID = mid, motherIndex = mix) #### Track Particle properties
				ptcevt.setChild(ptctrk)
			eventlist.append(ptcevt)
		return eventlist

	def __markEvent(self, filledeventlist):
		"""
			Link tracks into decays and tag every event.

			Returns one new event Particle per filled event (same order). Its children
			are the tracks that received at least one child. Properties set:
			  decayMark    "mother_child_..==" for every decaying track
			  fsMark       "id_id_.._" of the final-state tracks, sorted by ID
			  decayID/fsID position of the mark among the distinct marks (sorted)
			  wdecayID     matching markDecay IDs combined as id1*1000+id2..., -1 if none
			  wparticleID  matching markParticle IDs combined the same way, -1 if none
			  ndecayID/nfsID number of events sharing the decayID/fsID
			Note: the tracks of filledeventlist are modified in place (children attached).
		"""
		self._checkCopyright()
		eventlist = list()
		for event in filledeventlist:
			ptcevt = particle.Particle(id = event.getID(), index = event.getIndex())
			decaymark, fslist = "", list()
			## for relationship and restore
			for track in event.getChildren():
				trackid = track.getProperty("trackID")
				trackindex = track.getProperty("trackIndex")
				children = [ptc for ptc in event.getChildren() if ptc.getProperty("motherIndex") == trackindex]
				for child in children:
					# with several children, skip one carrying the mother's own ID
					if len(children) > 1 and trackid == child.getProperty("trackID"): continue
					# do not attach the same child (same trackIndex) twice
					childindex = child.getProperty("trackIndex")
					if any(ptc.getProperty("trackIndex") == childindex for ptc in track.getChildren()): continue
					track.setChild(child)
				if track.getNChildren() != 0:
					decaymark = decaymark + "%i"%trackid + "".join("_%i"%child.getProperty("trackID") for child in track.getChildren()) + "=="
					ptcevt.setChild(track)
				if track.getNChildren() <= 1 and abs(trackid) in self._FINALSTATEIDS:
					fslist.append(track)
			fsmark = "".join("%i_"%track.getProperty("trackID") for track in sorted(fslist, key = lambda x: x.getProperty("trackID")))
			ptcevt.setProperties(decayMark = decaymark, fsMark = fsmark)
			eventlist.append(ptcevt)

		### unique list for decayMark and fsMark; the position in the sorted list is the ID
		uniquedecaymarklist = sorted(set(event.getProperty("decayMark") for event in eventlist), key = lambda keys: [ord(c) for c in keys])
		uniquefsmarklist = sorted(set(event.getProperty("fsMark") for event in eventlist), key = lambda keys: [ord(c) for c in keys])
		decayidof = dict((mark, index) for index, mark in enumerate(uniquedecaymarklist))
		fsidof = dict((mark, index) for index, mark in enumerate(uniquefsmarklist))

		for event, filled in zip(eventlist, filledeventlist):
			decayid = decayidof[event.getProperty("decayMark")]
			fsid = fsidof[event.getProperty("fsMark")]

			## every decay in the four normalised forms markDecay can produce:
			## with the mother ID or with 0 (= any mother), without or with the FSR tag 0
			list_mother_children_fsr, list_mother_children_nofsr = list(), list()
			list_children_fsr, list_children_nofsr = list(), list()
			for track in event.getChildren():
				childids = [child.getProperty("trackID") for child in track.getChildren()]
				fsrtag = [0] if -22 in childids else []
				nonfsrids = sorted(cid for cid in childids if cid not in [-22])
				mother_children_nofsr = [track.getProperty("trackID")] + nonfsrids
				children_nofsr = [0] + nonfsrids
				list_mother_children_nofsr.append(tuple(mother_children_nofsr))
				list_mother_children_fsr.append(tuple(mother_children_nofsr + fsrtag))
				list_children_nofsr.append(tuple(children_nofsr))
				list_children_fsr.append(tuple(children_fsr)) if False else list_children_fsr.append(tuple(children_nofsr + fsrtag))
			candidates = (list_mother_children_nofsr, list_mother_children_fsr, list_children_nofsr, list_children_fsr)

			## wanted decays: every (decay, multiplicity) must occur often enough in some form
			markdecayid = -1
			for markid, decays in zip(self._markdecayids, self._decayids):
				if all(any(forms.count(decay) >= ndecay for forms in candidates) for decay, ndecay in decays):
					markdecayid = markid if markdecayid == -1 else markdecayid*1000 + markid

			### wanted particles: counted over all filled tracks of the event
			list_all_particles = [ptc.getProperty("trackID") for ptc in filled.getChildren()]
			markparticleid = -1
			for markid, wantedparticles in zip(self._markparticleids, self._particleids):
				if all(list_all_particles.count(ptcid) >= nptc for ptcid, nptc in wantedparticles):
					markparticleid = markid if markparticleid == -1 else markparticleid*1000 + markid

			event.setProperties(wdecayID = markdecayid, wparticleID = markparticleid, decayID = decayid, fsID = fsid)

		### count items (Counter: one pass instead of list.count per event)
		decaycounts = collections.Counter(event.getProperty("decayID") for event in eventlist)
		fscounts = collections.Counter(event.getProperty("fsID") for event in eventlist)
		for event in eventlist:
			event.setProperties(ndecayID = decaycounts[event.getProperty("decayID")], nfsID = fscounts[event.getProperty("fsID")])

		return eventlist
