#!/usr/bin/env python3
###########################################################################
##                                                                       ##
##                Centre for Speech Technology Research                  ##
##                   (Edinburgh University, UK) and                      ##
##                           Korin Richmond                              ##
##                         Copyright (c) 2004                            ##
##                         All Rights Reserved.                          ##
##                                                                       ##
##  Permission is hereby granted, free of charge, to use and distribute  ##
##  this software and its documentation without restriction, including   ##
##  without limitation the rights to use, copy, modify, merge, publish,  ##
##  distribute, sublicense, and/or sell copies of this work, and to      ##
##  permit persons to whom this work is furnished to do so, subject to   ##
##  the following conditions:                                            ##
##                                                                       ##
##   1. The code must retain the above copyright notice, this list of    ##
##      conditions and the following disclaimer.                         ##
##   2. Any modifications must be clearly marked as such.                ##
##   3. Original authors' names are not deleted.                         ##
##   4. The authors' names are not used to endorse or promote products   ##
##      derived from this software without specific prior written        ##
##      permission.                                                      ##
##                                                                       ##
##  EDINBURGH UNIVERSITY AND THE CONTRIBUTORS TO THIS WORK DISCLAIM      ##
##  ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL           ##
##  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT       ##
##  SHALL EDINBURGH UNIVERSITY OR THE CONTRIBUTORS BE LIABLE FOR         ##
##  ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES        ##
##  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER      ##
##  IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION,       ##
##  ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF       ##
##  THIS SOFTWARE.                                                       ##
##                                                                       ##
###########################################################################
##                                                                       ##
## Script to assemble and normalise join cost coefficients.              ##
##                                                                       ##
## The normalisation here uses simple mean and standard deviation.       ##
## This is equivalent to using a Mahalanobis distance metric when        ##
## the join cost at runtime is simple Euclidean distance, but the        ##
## variance scaling is done ahead of time...                             ##
##                                                                       ##
## AUTHOR: KORIN RICHMOND                                                ##
###########################################################################

# Timing values handed to est.Track.load() as its second and third arguments
# for every coefficient file (presumably the frame shift and the time of the
# first frame, in seconds -- confirm against the EST Python bindings).
mfcc_shift = 0.002
mfcc_frame1 = 0.005

# Names assigned, in channel order, to the channels of every saved track:
# eleven numbered mel-cepstral channels followed by 'norm_melcep_N', the
# energy channel and the f0 channel.
channel_names = ['norm_melcep_%d' % index for index in range(1, 12)]
channel_names += ['norm_melcep_N', 'norm_energy', 'norm_f0']

import sys
import os, re
import math

def listFiles( directory, pattern ):
    """Return the paths of the entries of *directory* whose names match *pattern*.

    directory -- directory to scan, with or without a trailing separator
    pattern   -- regular expression (string or compiled pattern); it is
                 applied with match(), so it is anchored at the start of
                 the entry name only, not at its end

    Returns a sorted list of paths, each formed by joining *directory* and
    the matching entry name.  Errors raised by os.listdir (for example a
    missing directory) are propagated unchanged.
    """
    pattern = re.compile( pattern )

    print(" pattern compiled")

    # os.path.join adds the separator only when it is missing, so the result
    # is a valid path whether or not the caller appended one.  Sorting makes
    # the otherwise arbitrary os.listdir order reproducible between runs.
    return [ os.path.join( directory, entry )
             for entry in sorted( os.listdir( directory ) )
             if pattern.match( entry ) ]


def usageAndQuit(progname):
    """Print the command-line usage message and exit with status 1.

    progname -- program name shown in the message (normally sys.argv[0])

    Never returns: sys.exit raises SystemExit with code 1.
    """
    # The third argument is the directory that is scanned for files matching
    # the pattern, not a second output directory, so it is named accordingly.
    # The backslash is doubled so that the pattern example no longer relies
    # on an invalid string escape sequence.
    print("Usage: %s output_dir f0_dir mfcc_dir '.*\\.mfcc'" % progname)
    sys.exit(1)

# The est module is not on the default module search path: the EST_PYTHON
# environment variable names the directory that is added to sys.path so that
# the import below can succeed.
try:
    estmoduledir = os.environ["EST_PYTHON"]
except KeyError:
    print("\n** environment variable EST_PYTHON is unset **\n")
    # Stop here: carrying on would only fail a moment later with a NameError
    # on the undefined estmoduledir, hiding the real cause.
    sys.exit(1)

sys.path.append( estmoduledir )
import est

# Exactly four arguments are required: the output directory, the f0
# directory, the directory to scan and the file name pattern.
if len(sys.argv) != 5:
    usageAndQuit(sys.argv[0])

outputdir = sys.argv[1]
f0dir = sys.argv[2]
# trailing '/' so that the paths returned by listFiles() read 'dir/name'
fileDir = sys.argv[3] + '/'
filePat = sys.argv[4]

# all three directories must already exist
if not os.path.isdir(f0dir) or not os.path.isdir(outputdir) or not os.path.isdir(fileDir):
    usageAndQuit(sys.argv[0])

fileList = listFiles( fileDir, filePat )

# With no matching file there is nothing to normalise, and the first-file
# lookup (fileList[0]) made later in the script would raise IndexError;
# report the problem and stop instead.
if not fileList:
    print("No files in %s match pattern '%s'" % (fileDir, filePat))
    sys.exit(1)

# Track objects reused for every file: each load() call in the loops below
# targets these same two objects.
mfcc_track = est.Track()
f0_track = est.Track()

# First calculate the normalising parameters: means & standard deviations.
# In case there is *a lot* of speech data in the voice to be processed
# we avoid loading all the data at once and accumulate per-file statistics.

# The first file is loaded once only to find out how many channels the
# coefficient tracks have.
# NOTE(review): fileList[0] raises IndexError if no file matched the pattern.
temp = est.Track()
temp.load(fileList[0], mfcc_shift, mfcc_frame1)
nchans = temp.num_channels()

# Per-file mean / s.d. outputs (filled by est.meansd) and the running
# accumulators for the whole corpus, one element per channel.
# NOTE(review): the *_acc vectors are accumulated into with += without being
# zeroed explicitly -- this assumes est.FVector(n) starts zero-filled; confirm.
mfcc_mean = est.FVector(nchans)
mfcc_sd = est.FVector(nchans)
mfcc_mean_acc = est.FVector(nchans)
mfcc_var_acc = est.FVector(nchans)
# total number of coefficient frames seen so far (kept as a float)
mfcc_len = 0.0

# f0 statistics are kept as plain running sums over voiced frames only:
# the sum, the sum of squares and the number of values.
f0_chan = est.FVector()
f0_sum = 0.0
f0_ssum = 0.0
f0_len = 0.0

# First pass over the data: gather the per-file statistics needed for the
# corpus-wide mean and standard deviation, one file at a time.
for mfcc_path in fileList:

    mfcc_track.load(mfcc_path, mfcc_shift, mfcc_frame1)

    # utterance name: the file name without its directory and extension
    stem = os.path.splitext(os.path.basename(mfcc_path))[0]
    print(stem)

    # the matching f0 track is looked up by utterance name in f0dir
    f0_track.load( "%s/%s.f0" % (f0dir, stem) )

    # --- mfcc and energy channels ---------------------------------------
    # per-file mean and s.d. are written into mfcc_mean / mfcc_sd
    est.meansd( mfcc_track, mfcc_mean, mfcc_sd )
    n_frames = mfcc_track.length()

    # mean -> sum over this file's frames
    mfcc_mean *= n_frames
    mfcc_mean_acc += mfcc_mean

    # s.d. -> variance (elementwise square) -> scaled by (frames - 1)
    mfcc_sd *= mfcc_sd
    mfcc_sd *= (n_frames - 1)
    mfcc_var_acc += mfcc_sd

    mfcc_len += n_frames

    # --- f0 channel -----------------------------------------------------
    # running sums over voiced frames only; unvoiced frames carry the
    # value -1 and are skipped
    f0_track.copy_channel_out( 0, f0_chan )
    for frame in range(f0_chan.length()):
        value = f0_chan[frame]
        if value == -1:
            continue
        f0_sum += value
        f0_ssum += value * value
        f0_len += 1.0
            
# prepare means and standard deviations

# Both variance estimates divide by (count - 1), so each needs at least two
# observations; stop with a clear message rather than a ZeroDivisionError.
if mfcc_len < 2 or f0_len < 2:
    print("Too little data to estimate normalisation parameters "
          "(%d coefficient frames, %d voiced f0 values)" % (mfcc_len, f0_len))
    sys.exit(1)

# accumulated sums -> mean, accumulated scaled variances -> s.d.
mfcc_mean_acc /= mfcc_len
mfcc_mean = mfcc_mean_acc
mfcc_var_acc /= (mfcc_len-1)
mfcc_sd = est.sqrt(mfcc_var_acc)

f0_mean = f0_sum/f0_len
# sample variance from the running sum and sum of squares
f0_var = ((f0_len*f0_ssum)-(f0_sum*f0_sum))/(f0_len*(f0_len-1))
# For (near-)constant f0 data rounding can leave the variance at zero or
# slightly below it; math.sqrt would then fail, or the later division by
# f0_sd would raise ZeroDivisionError.
if f0_var <= 0.0:
    print("f0 values have no variance: cannot normalise f0")
    sys.exit(1)
f0_sd = math.sqrt(f0_var)
# f0 is later divided by this value; with the factor 4 that division maps
# the range mean +/- 2 s.d. onto an interval of width 1.
f0_sd *= 4

# Second pass over the data: load each file again, normalise it with the
# corpus statistics, attach the normalised f0 channel and save the result.
for mfcc_path in fileList:
    mfcc_track.load(mfcc_path, mfcc_shift, mfcc_frame1)

    # utterance name: the file name without its directory and extension
    stem = os.path.splitext(os.path.basename(mfcc_path))[0]
    print(stem)

    # NOTE(review): the trailing 1.0 / 0.0 arguments are passed through as in
    # the EST call signature -- presumably range bounds; confirm.
    est.normalise( mfcc_track, mfcc_mean, mfcc_sd, 1.0, 0.0 )

    # f0 is normalised by hand so that unvoiced frames keep their -1 marker
    f0_track.load( "%s/%s.f0" % (f0dir, stem) )
    f0_track.copy_channel_out( 0, f0_chan )
    for frame in range(f0_chan.length()):
        value = f0_chan[frame]
        if value == -1:
            continue
        f0_chan[frame] = ((value - f0_mean)/f0_sd)+ 0.5

    # append f0 as one extra channel after the existing nchans channels
    mfcc_track.resize( mfcc_track.length(), nchans+1 )
    mfcc_track.copy_channel_in( nchans, f0_chan )

    # NOTE(review): assumes the combined track has at least
    # len(channel_names) channels -- confirm for other coefficient orders.
    for channel, label in enumerate(channel_names):
        mfcc_track.set_channel_name( label, channel )

    mfcc_track.save( '%s/%s.coef' % (outputdir, stem), 'est_binary' )

                                                                                                                                                                                  