#!/opt/anaconda/envs/gpython2/bin/python
# -*- coding: utf-8 -*-

from __future__ import print_function
import os
import sys
import argparse


from openeye.oechem import *
from openeye.oeshape import *
from openeye.oegrid import *


# Cavity pseudo-atom name -> (atom type, atomic number, isotope).
# The isotope is what separates classes that share an element (e.g. N donor
# vs. N cation); presumably the custom .cff colour force field keys on it --
# NOTE(review): confirm against the .cff file in use.
_CAVITY_ATOM_PROPS = {
    'CA': ('C.3', 6, 13),     # hydrophobe
    'CZ': ('C.ar', 6, 15),    # aromatic
    'N': ('N.am', 7, 14),     # donor
    'O': ('O.2', 8, 14),      # acceptor
    'NZ': ('N.4', 7, 15),     # cation
    'OG': ('O.3', 8, 15),     # donor/acceptor
    'OD1': ('O.co2', 8, 17),  # anion
    'Zn': ('Zn', 30, 54),     # metal
    'DU': ('H', 1, 2),        # dummy
}


def ChangeAtomCav(cav):
    """Replace atom type and isotope of the cavity pseudo-atoms, in place.

    Every atom whose name is a key of ``_CAVITY_ATOM_PROPS`` receives the
    listed atom type, atomic number and isotope, and its implicit hydrogen
    count is set to 0. Atoms with any other name are left untouched.

    Args:
        cav: cavity molecule whose atoms are retyped.

    Returns:
        The same ``cav`` object (modified in place), for call chaining.
    """
    for atom in cav.GetAtoms():
        props = _CAVITY_ATOM_PROPS.get(atom.GetName())
        if props is None:
            continue
        atom_type, atomic_num, isotope = props
        atom.SetType(atom_type)
        atom.SetAtomicNum(atomic_num)
        # Applied after the atomic number for every entry, so the requested
        # isotope is always the last value written.
        atom.SetIsotope(isotope)
        atom.SetImplicitHCount(0)
    return cav

####################################################################################################
############################################### MAIN ###############################################
####################################################################################################
# def main(reffile, fitfile, outfile, tsvfile, cfffile, met, ini, sco, rad, rtyp, verb, maxstp):
def main(reffile, fitfile, outfile, tsvfile, cfffile, met, ini, sco, maxstp, keepsize, carbrad, verb):
    """Overlay every fit molecule onto a reference cavity and report the scores.

    Args:
        reffile: path to the reference cavity file; its first molecule is read
            and its pseudo-atoms are retyped with ``ChangeAtomCav``.
        fitfile: path to the (multi-molecule) fit file.
        outfile: path for the overlaid fit poses, or a falsy value to skip.
        tsvfile: path for the tab-separated score table, or a falsy value to skip.
        cfffile: path to the colour force field (.cff) file.
        met: overlap-function name (see ``_make_overlap_func``).
        ini: starting-orientation name (see ``_make_starts``).
        sco: score used to rank overlays (see ``_make_sorter``); unknown
            names fall back to TanimotoCombo.
        maxstp: maximum number of optimisation iteration steps.
        keepsize: number of best-ranked overlays kept per fit molecule; a
            non-positive value keeps all of them.
        carbrad: carbon radius for OEShapeOptions (only the ExactShape*
            overlap functions consume these options).
        verb: when true, echo the score table (and a separator line per fit
            molecule) to stdout.

    Exits the process via ``sys.exit`` if the reference cavity or the colour
    force field cannot be read.
    """
    reffs = oemolistream(reffile)
    refcav = OEGraphMol()
    if not OEReadMolecule(reffs, refcav):
        sys.exit("reference cavity file {} is missing or not readable\n".format(reffile))
    refmol = ChangeAtomCav(refcav)

    cff = _load_color_ff(cfffile)

    # Reference prep: assign colour atoms with the custom force field
    # (otherwise the ImplicitMillsDean force field would be used).
    prep = OEOverlapPrep()
    prep.SetColorForceField(cff)
    prep.Prep(refmol)

    colopt = OEColorOptions()
    colopt.SetColorForceField(cff)
    shopt = OEShapeOptions()
    shopt.SetCarbonRadius(carbrad)

    options = OEOverlayOptions()
    options.SetMaxOptSteps(maxstp)
    starts = _make_starts(ini)
    if starts is not None:
        options.SetStarts(starts)
    overlap_func = _make_overlap_func(met, shopt, colopt)
    if overlap_func is not None:
        options.SetOverlapFunc(overlap_func)

    overlay = OEOverlay(options)
    overlay.SetupRef(refmol)
    sorter = _make_sorter(sco)

    fitfs = oemolistream(fitfile)
    fitfs.SetConfTest(OEAbsCanonicalConfTest())
    outfs = oemolostream(outfile) if outfile else None
    tsvhdl = open(tsvfile, 'w') if tsvfile else None

    columns = ("Ref", "Fit", "TanCombo", "TanShape", "TanColor", "RefTvCombo", "RefTvShape",
               "RefTvColor", "RefSelfColor", "RefSelfOverlap", "FitTvCombo", "FitTvShape",
               "FitTvColor", "FitSelfColor", "FitSelfOverlap", "ColorOverlap", "Overlap")
    header = "{0:>15}\t{1:>15}\t".format(columns[0], columns[1]) + "\t".join(columns[2:])

    try:
        if verb:
            print(header)
        if tsvhdl is not None:
            tsvhdl.write(header + "\n")

        for fitmol in fitfs.GetOEMols():
            prep.Prep(fitmol)
            scoreiter = OEBestOverlayScoreIter()
            OESortOverlayScores(scoreiter, overlay.Overlay(fitmol), sorter)

            if verb:
                print("------------------------------------------------------------------------------------------------------")

            kept = 0
            for score in scoreiter:
                outmol = OEGraphMol(fitmol.GetConf(OEHasConfIdx(score.GetFitConfIdx())))
                score.Transform(outmol)
                OERemoveColorAtoms(outmol)
                if outfs is not None:
                    OEWriteMolecule(outfs, outmol)
                if tsvhdl is not None or verb:
                    row = _score_row(refmol, outmol, score)
                    if tsvhdl is not None:
                        tsvhdl.write(row + "\n")
                    if verb:
                        print(row)
                kept += 1
                # '==' (not '>=') so that keepsize <= 0 keeps every overlay.
                if kept == keepsize:
                    break
    finally:
        if tsvhdl is not None:
            tsvhdl.close()
        if outfs is not None:
            outfs.close()


def _load_color_ff(cfffile):
    """Return an OEColorForceField initialised from *cfffile*; exit on failure."""
    cff = OEColorForceField()
    try:
        loaded = cff.Init(cfffile, False)
    except Exception as err:
        sys.exit("color force field file (.cff) {} is missing or not readable: {}\n".format(cfffile, err))
    # Init reports failure through its return value rather than raising;
    # 'is False' avoids treating a status-less (None) return as failure.
    if loaded is False:
        sys.exit("color force field file (.cff) {} is missing or not readable\n".format(cfffile))
    return cff


def _make_starts(ini):
    """Return a new starting-orientation object for *ini*, or None if unknown."""
    factories = {
        'Inertial': OEInertialStarts,
        'AsIs': OEAsIsStarts,
        'Random': OERandomStarts,
        'InertialAtHeavyAtoms': OEAtAtomStarts,
        'Subrocs': OESubrocsStarts,
    }
    factory = factories.get(ini)
    return factory() if factory is not None else None


def _make_overlap_func(met, shopt, colopt):
    """Return the shape and/or colour overlap function named *met*, or None if unknown.

    Only the ExactShape* variants use *shopt*; every colour function uses *colopt*.
    The builders are lambdas so that only the selected function is constructed.
    """
    builders = {
        'GridShape': lambda: OEGridShapeFunc(),
        'GridColor': lambda: OEGridColorFunc(colopt),
        'GridShapeGridColor': lambda: OEOverlapFunc(OEGridShapeFunc(), OEGridColorFunc(colopt)),
        'GridShapeAnalyticColor': lambda: OEOverlapFunc(OEGridShapeFunc(), OEAnalyticColorFunc(colopt)),
        'GridShapeExactColor': lambda: OEOverlapFunc(OEGridShapeFunc(), OEExactColorFunc(colopt)),
        'AnalyticShape': lambda: OEAnalyticShapeFunc(),
        'AnalyticColor': lambda: OEAnalyticColorFunc(colopt),
        'AnalyticShapeAnalyticColor': lambda: OEOverlapFunc(OEAnalyticShapeFunc(), OEAnalyticColorFunc(colopt)),
        'AnalyticShapeGridColor': lambda: OEOverlapFunc(OEAnalyticShapeFunc(), OEGridColorFunc(colopt)),
        'AnalyticShapeExactColor': lambda: OEOverlapFunc(OEAnalyticShapeFunc(), OEExactColorFunc(colopt)),
        'ExactShape': lambda: OEExactShapeFunc(shopt),
        'ExactColor': lambda: OEExactColorFunc(colopt),
        'ExactShapeExactColor': lambda: OEOverlapFunc(OEExactShapeFunc(shopt), OEExactColorFunc(colopt)),
        'ExactShapeGridColor': lambda: OEOverlapFunc(OEExactShapeFunc(shopt), OEGridColorFunc(colopt)),
        'ExactShapeAnalyticColor': lambda: OEOverlapFunc(OEExactShapeFunc(shopt), OEAnalyticColorFunc(colopt)),
    }
    build = builders.get(met)
    return build() if build is not None else None


def _make_sorter(sco):
    """Return the ranking predicate for score name *sco* (default: TanimotoCombo)."""
    predicates = {
        'TanimotoCombo': OEHighestTanimotoCombo,
        'Tanimoto': OEHighestTanimoto,
        'ColorTanimoto': OEHighestColorTanimoto,
        'ScaledColor': OEHighestScaledColor,
        'ComboScore': OEHighestComboScore,
        'FitColorTversky': OEHighestFitColorTversky,
        'FitTversky': OEHighestFitTversky,
        'FitTverskyCombo': OEHighestFitTverskyCombo,
        'Overlap': OEHighestOverlap,
        'RefColorTversky': OEHighestRefColorTversky,
        'RefTversky': OEHighestRefTversky,
        'RefTverskyCombo': OEHighestRefTverskyCombo,
    }
    return predicates.get(sco, OEHighestTanimotoCombo)()


def _score_row(refmol, outmol, score):
    """Format one tab-separated result row (no trailing newline)."""
    values = (score.GetTanimotoCombo(), score.GetTanimoto(), score.GetColorTanimoto(),
              score.GetRefTverskyCombo(), score.GetRefTversky(), score.GetRefColorTversky(),
              score.GetRefSelfColor(), score.GetRefSelfOverlap(),
              score.GetFitTverskyCombo(), score.GetFitTversky(), score.GetFitColorTversky(),
              score.GetFitSelfColor(), score.GetFitSelfOverlap(),
              score.GetColorScore(), score.GetOverlap())
    fmt = "{:>15}\t{:>15}" + "\t{:f}" * len(values)
    return fmt.format(refmol.GetTitle(), outmol.GetTitle(), *values)


if __name__ == "__main__":

    # Accepted values for --method, --ini and --score (mirroring the names
    # dispatched on inside main()).
    methods = ['GridShape', 'GridColor', 'GridShapeGridColor', 'GridShapeAnalyticColor', 'GridShapeExactColor',
               'AnalyticShape', 'AnalyticColor', 'AnalyticShapeAnalyticColor', 'AnalyticShapeGridColor',
               'AnalyticShapeExactColor', 'ExactShape', 'ExactColor', 'ExactShapeExactColor',
               'ExactShapeGridColor', 'ExactShapeAnalyticColor']
    starts = ['Inertial', 'AsIs', 'Random', 'InertialAtHeavyAtoms', 'Subrocs']
    scores = ['TanimotoCombo', 'Tanimoto', 'ColorTanimoto', 'ScaledColor', 'ComboScore', 'FitColorTversky',
              'FitTversky', 'FitTverskyCombo', 'Overlap', 'RefColorTversky', 'RefTversky', 'RefTverskyCombo']

    parser = argparse.ArgumentParser(description = 'Compute similarity between a Reference Cavity (ref) and a Set of ligands (fit)', epilog = "Good Luck !")
    parser.add_argument('-r', '--ref', help = 'Path to reference <cavity.mol2> file', required = True)
    parser.add_argument('-c', '--fit', help = 'Path to (multimol) <ligand.mol2> file', required = True)
    parser.add_argument('-o', '--output', help = 'Path to overlap results <output.mol2> file', required = False, default = False)
    parser.add_argument('-t', '--tsv', help = 'Path to results score <score.tsv> file', required = False, default = False)
    parser.add_argument('-f', '--cff', help = 'Path to color force field file <FILE.cff>', required = True)
    parser.add_argument('-m', '--method', help = 'Algorithm used to calculate overlap (default="GridShapeGridColor")', required = False, default = 'GridShapeGridColor', choices = methods)
    parser.add_argument('-i', '--ini', help = 'Starting point of the optimization (default="Inertial")', required = False, default = 'Inertial', choices = starts)
    parser.add_argument('-s', '--score', help = 'Scoring function for sorting overlays (default="TanimotoCombo")', required = False, default = 'TanimotoCombo', choices = scores)
    parser.add_argument('-a', '--maxstep', help = 'Set the maximum number of optimization iteration steps (default=200).', required = False, type = int, default = 200)
    parser.add_argument('-n', '--keepnb', help = 'Number to keep(default=1)', required = False, type = int, default = 1)
    parser.add_argument('-v', '--verbose', help = 'verbose (default=False)', action = 'store_true', required = False, default = False)
    parser.add_argument('-u', '--radius', help = 'Set the radius for atom overlap (default=1.7). Only valid for "ExactShape"', required = False, type = float, default = 1.7)
    parser.add_argument('--version', action = 'version', version = '%(prog)s 9.0')
    args = parser.parse_args()

    # os.path.abspath only rewrites the path string; it never opens the file,
    # so no I/O error handling is needed here. Readability of the inputs is
    # checked where they are actually read.
    refpath = os.path.abspath(args.ref)
    fitpath = os.path.abspath(args.fit)
    cffpath = os.path.abspath(args.cff)
    outpath = os.path.abspath(args.output) if args.output else False
    tsvpath = os.path.abspath(args.tsv) if args.tsv else False

    # argparse already converted maxstep/keepnb to int and radius to float.
    sys.exit(main(refpath, fitpath, outpath, tsvpath, cffpath, args.method, args.ini, args.score,
                  args.maxstep, args.keepnb, args.radius, args.verbose))


