# -*- coding: utf-8 -*-
"""
Created on Sun Mar 18 00:49:54 2018

@author: chris

This script solves the nonlinear Schrödinger (NLS) equation with the split-step
Fourier method. It is based on the code written by Govind P. Agrawal (March 2005)
for the book Nonlinear Fiber Optics (Appendix B).
"""
import numpy as np
import matplotlib.pyplot as plt
import time

# ---Physical input parameters (normalized units)
distance = 150.0  # propagation length, in units of the cavity length L_c
# 2nd-order dispersion: kappa = beta2 * f^2 * L / 2
# (positive for normal dispersion, negative for anomalous dispersion)
kappa = -0.001
sigma = 0.0  # 3rd-order dispersion: sigma = beta3 * f^3 * L / 6
G = 1.0  # small-signal gain of the Raman amplifier: G = g * L
Is = 1.0  # gain-saturation parameter
alpha = 0.4  # fiber amplitude absorption coefficient: alpha = l * L

# Nonlinear parameter:
#   n = sqrt(L_D / L_NL) = sqrt(gamma * P0 * T0^2 / |beta2|)
#   (the author's alternative note, "QT": n = kappa^0.5)
n = 2.0 ** 0.5

# ---Comb filter + band-pass filter (BPF) parameters
bdwidth = 2.0 * np.pi * 6.0  # BPF width in the Gaussian exp(-omega^2 / bdwidth^2)
delta = 2.0 * np.pi * 0.5  # detuning added to omega in the comb-filter phase
a = np.log(np.sqrt(0.95))  # log of the amplitude factor sqrt(0.95) in the comb denominator
perta = 0.3  # depth of the sin^2 perturbation of the comb filter
pertfsr = 0.17  # frequency scale of that perturbation: sin(0.5 * omega / pertfsr)
T = 0.2  # power transmission (t = sqrt(T), |r|^2 = 1 - T)
t = np.sqrt(T)  # amplitude transmission
r = 1.0j * np.sqrt(1.0 - T)  # amplitude reflection, carrying a pi/2 phase (factor 1j)

# ---Input pulse parameters
mshape = -1.0  # 0: sech (soliton), > 0: super-Gaussian of order mshape, < 0: white noise
chirp0 = 0.0  # input pulse chirp

# Normalization used throughout:
#   P   = 1 / (gamma * L)   reference peak power
#   uu  = A / sqrt(P)       normalized field
#   z   = z0 / L_c          z0: physical length, L_c: cavity length
#   tau = f * t             t: time in the reference (travelling) frame


# ---Simulation grid
nt = 2 ** 13  # number of FFT points (power of 2)
Tmax = 100.0  # half-width of the time window
stepno = round(20 * distance * n ** 2)  # number of z steps
dz = distance / stepno  # step size in z
dtau = (2.0 * Tmax) / nt  # step size in tau

# Plot windows
Twin = 5.0  # half-width of the magnified time window
fmax = (1.0 / (2.0 * Tmax)) * nt / 2.0  # Nyquist frequency, equal to 1 / (2 * dtau)
fwin = 5.0  # half-width of the magnified frequency window

filterz = 0.5  # position within each unit of z at which the comb filter is applied
plotz = 5  # z interval between monitoring plots

# Temporal grid, and angular-frequency grid in FFT order
# (non-negative frequencies first, then the negative ones)
tau = np.arange(-nt / 2.0, nt / 2.0) * dtau
omega = (np.pi / Tmax) * np.concatenate(
    (np.arange(0.0, nt / 2.0), np.arange(-nt / 2.0, 0.0))
)

# Delay grid covering about [-Twin, Twin] in steps of dtau
delaytau = dtau * np.arange(-round(Twin / dtau), round(Twin / dtau) + 1)

# ---Input field profile, selected by mshape
if mshape == 0:
    # Soliton: sech pulse with an optional quadratic (linear-chirp) phase
    uu = np.exp(-0.5j * chirp0 * tau ** 2.0) / np.cosh(tau)
elif mshape > 0:
    # Super-Gaussian of order mshape. Using |tau| keeps the profile symmetric
    # and finite for non-integer orders: a negative base raised to a
    # non-integer power is NaN, and an odd power would make the exponent
    # grow for tau < 0. For integer orders the result is unchanged.
    uu = np.exp(-0.5 * (1.0 + 1.0j * chirp0) * np.abs(tau) ** (2.0 * mshape))
else:
    # Complex white noise: each quadrature has variance 1/2, giving unit
    # mean power per sample
    uu = (np.random.randn(nt) + 1.0j * np.random.randn(nt)) * np.sqrt(0.5)

# Angular-frequency axis reordered from negative to positive values, matching
# np.fft.fftshift applied to spectra on the omega grid
tempomega = np.fft.fftshift(omega)

# ---Precomputed linear operators (computed once, outside the main loop)
# Linear step factor for one step dz, applied to the field in the frequency
# domain: amplitude loss (alpha), 2nd-order (kappa) and 3rd-order (sigma)
# dispersion.
dispersion = np.exp(
    (-alpha + 1.0j * kappa * omega ** 2.0 + 1.0j * sigma * omega ** 3.0) * dz
)

# Filter transfer function.
# Unperturbed comb filter + BPF (kept for reference, not used):
#   filtert = exp(-omega**2 / bdwidth**2) * t**2
#             / (1 - r**2 * exp(-1j * (omega + delta) + a))
# Perturbed comb filter + BPF (used): Gaussian BPF with an extra sin^2
# attenuation term, divided by the comb denominator.
filtert = (
    t ** 2
    * np.exp(-omega ** 2.0 / bdwidth ** 2.0 - perta * np.sin(0.5 * omega / pertfsr) ** 2)
    / (1.0 - r ** 2 * np.exp(-1.0j * (omega + delta) + a))
)

# ---Plot the filter power transmission 10*log10(|filtert|^2) versus
# frequency (omega / 2pi), centered with fftshift
fig1 = plt.figure()
plt.plot(
    tempomega / (2.0 * np.pi), np.fft.fftshift(10.0 * np.log10(np.abs(filtert) ** 2.0))
)
plt.title("Perturbated comb filter + BPF")
plt.xlim(-fwin, fwin)
plt.ylim(-30.0, 10.0)
plt.show()

# ---Main loop setup
# Symmetric split-step scheme: 1/2 N -> D -> 1/2 N. The leading half
# nonlinear step (length dz/2) is taken here. Inside the loop, each linear
# step is followed by a full nonlinear step, which merges consecutive half steps.
# Nonlinear operator: Kerr phase 1j*|u|^2 plus saturable gain G/(1 + |u|^2/Is).
z = 0
temp = uu * np.exp(
    (1.0j * np.abs(uu) ** 2.0 + G / (1.0 + np.abs(uu) ** 2.0 / Is)) * dz / 2.0
)

# Wall-clock timing of the propagation
start_time = time.time()
time_used = time.time() - start_time

# ---Real-time monitoring of the simulation progress
# plt.subplots creates its own figure. No separate plt.figure() call is
# needed; one here would only open an empty, unused window.
fig2, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(3, 2)  # ax6 is unused
plt.tight_layout()

# Placeholder curves; their data are replaced in the main loop
abs_temp = np.abs(temp)
autocorr0 = None  # filled in by the monitoring branch of the main loop
line1, = ax1.plot(abs_temp)
line2, = ax2.plot(abs_temp)
line3, = ax3.plot(abs_temp)
line4, = ax4.plot(abs_temp)
line5, = ax5.plot(abs_temp)
plt.ion()  # interactive mode, so the plots can be refreshed during the loop

# Main propagation loop. Each iteration does two things:
#   1. One linear step in the frequency domain, optionally with the comb
#      filter. In this script np.fft.ifft goes to the frequency domain and
#      np.fft.fft returns to the time domain.
#   2. One full nonlinear step of length dz in the time domain.
for i in range(stepno):
    # Apply the filter when z mod 1 is within about half a step of filterz,
    # i.e. once per unit of z (z is normalized to the cavity length L_c).
    if round((z % 1 - filterz) / dz) == 0:
        ftemp = np.fft.ifft(temp) * filtert * dispersion
    else:
        ftemp = np.fft.ifft(temp) * dispersion

    # Back to the time domain. ftemp is kept for the spectrum plots below.
    uu = np.fft.fft(ftemp)
    # Full nonlinear step: Kerr phase plus saturable gain.
    temp = uu * np.exp(
        (1.0j * np.abs(uu) ** 2.0 + G / (1.0 + np.abs(uu) ** 2 / Is)) * dz
    )
    z = z + dz

    # Refresh the plots when z is within about half a step of a multiple of
    # plotz. The second test catches z % plotz landing just below plotz
    # because of floating-point accumulation in z.
    if round((z % plotz) / dz) == 0 or round(((z % plotz) - plotz) / dz) == 0:
        # time_used is only consumed by the commented-out suptitle below
        time_used = time.time() - start_time
        print("Z: " + str(z))
        # fig2.suptitle("i = " + str(i) + ", z = " + str(z) + " Time: " + str(time_used))

        # Intensity |temp|^2 on the magnified time window
        line1.set_data(tau, np.abs(temp) ** 2.0)
        ax1.relim()
        ax1.autoscale_view(True, True, True)
        ax1.set_xlim([-Twin, Twin])
        ax1.set_title("Time domain (Magnified)")

        # Spectrum of ftemp (field after the last linear step), centered with
        # fftshift and scaled by nt*dtau/sqrt(2*pi)
        ftemp0 = np.fft.fftshift(ftemp * (nt * dtau) / np.sqrt(2 * np.pi))

        # Power spectrum in dB on the magnified frequency window
        line2.set_data(
            tempomega / (2.0 * np.pi), 10.0 * np.log10(np.abs(ftemp0) ** 2.0)
        )
        #
        ax2.relim()
        ax2.autoscale_view(True, True, True)
        ax2.set_xlim([-fwin, fwin])
        ax2.set_title("Spectrum (Magnified)")

        # Intensity over the full time window
        line3.set_data(tau, np.abs(temp) ** 2)
        ax3.relim()
        ax3.autoscale_view(True, True, True)
        ax3.set_xlim([-Tmax, Tmax])
        ax3.set_title("Time domain (Full Scale)")

        # Power spectrum in dB up to the Nyquist frequency fmax
        line4.set_data(
            tempomega / (2.0 * np.pi), 10.0 * np.log10(np.abs(ftemp0) ** 2.0)
        )
        ax4.relim()
        ax4.autoscale_view(True, True, True)
        ax4.set_xlim([-fmax, fmax])
        ax4.set_title("Spectrum (Full Scale)")

        # Circular intensity autocorrelation computed with FFTs:
        # ifft(F * conj(F)) with F = fft(|temp|^2). The fftshift puts zero
        # delay at index nt/2, where tau == 0.
        autocorr0 = np.fft.fftshift(
            np.fft.ifft(
                np.fft.fft(np.abs(temp) ** 2)
                * np.conjugate(np.fft.fft(np.abs(temp) ** 2))
            )
        )

        # Autocorrelation normalized to its peak
        line5.set_data(tau, np.abs(autocorr0) / max(np.abs(autocorr0)))
        ax5.relim()
        ax5.autoscale_view(True, True, True)
        ax5.set_xlim([-Twin, Twin])
        ax5.set_title("Autocorrelation")

        # Let matplotlib redraw the figure
        plt.pause(0.1)
# Leave interactive mode and keep the final figures on screen
plt.ioff()
plt.show()


# ---Exporting results
# np.savetxt writes each array of the tuple as one row of the CSV file.
fname = "result"  # base file name

# Spectrum of the exported field uu, with the same centering and scaling as
# the monitoring plot. The frequency file must hold this spectrum, not the
# time-domain field, because its first row is the frequency axis.
final_spectrum = np.fft.fftshift(np.fft.ifft(uu) * (nt * dtau) / np.sqrt(2.0 * np.pi))

# Intensity autocorrelation of the final field. It is recomputed here so the
# export is valid (and current) even when the last step did not coincide
# with a monitoring frame (e.g. distance not a multiple of plotz).
final_intensity_ft = np.fft.fft(np.abs(temp) ** 2)
autocorr0 = np.fft.fftshift(
    np.fft.ifft(final_intensity_ft * np.conjugate(final_intensity_ft))
)
autocorr_norm = np.abs(autocorr0) / np.abs(autocorr0).max()

np.savetxt(
    fname + "_time.csv", (tau, np.real(uu), np.imag(uu), np.abs(uu) ** 2), delimiter=","
)
np.savetxt(
    fname + "_freq.csv",
    (
        tempomega / (2.0 * np.pi),
        np.real(final_spectrum),
        np.imag(final_spectrum),
        np.abs(final_spectrum) ** 2,
        np.fft.fftshift(np.abs(filtert) ** 2.0),
    ),
    delimiter=",",
)
np.savetxt(
    fname + "_autocorr.csv",
    (tau, np.abs(autocorr0), autocorr_norm),
    delimiter=",",
)