%% spike_noise_rate.m
% Simulate auditory-nerve spike trains for a 100 Hz tone at two intensities,
% and illustrate how noise turns a staircase rate–level function into a
% smooth, useful code for stimulus level.
%
% This script reproduces Fig. fig_fig8:
% A, B: example spike trains at two levels.
% C: rate–level curves with and without noise.

clear; close all; clc;

%% Check for optional helper functions
% Each flag is true when the named helper is on the MATLAB path; the rest of
% the script falls back to built-in functionality when it is not.
hasNicegraph = ~isempty(which('nicegraph'));
hasSavegraph = ~isempty(which('savegraph'));
hasCbrewer   = ~isempty(which('cbrewer'));
hasBftext    = ~isempty(which('bf_text'));

% Single accent colour (1x3 RGB row) shared by all panels.
if hasCbrewer
	% Second colour of the qualitative 'Pastel1' palette from cbrewer
	col = cbrewer('qual','Pastel1',3);
	col = col(2,:);
else
	% Fallback: second row (colour) of the built-in 'lines' colormap.
	% Must index the colormap matrix, not col, which is undefined here.
	colors = lines(3);
	col = colors(2,:);
end
%% General parameters
rng(0);                          % fixed seed: identical spike trains on every run

% Stimulus: pure tone
f0      = 100;                   % tone frequency (Hz)
T       = 0.15;                  % duration of each simulated spike train (s)
nCycles = round(f0 * T);         % whole stimulus cycles that fit in T
w       = 2*pi*f0;               % angular frequency (rad/s)
t       = linspace(0, T, 1000);  % time axis for drawing the waveform (s)

% The two example sound levels (arbitrary units, e.g. dB SPL)
L_low  = 15;
L_high = 50;

% Linear, clipped rate–level function for the "with noise" case:
% 0 at or below L_thr, rising linearly to R_max at L_max, flat above.
L_thr = 0;     % threshold level
L_max = 80;    % saturation level
R_max = 200;   % saturated firing rate (spikes/s)

%% Helper: rate–level function (mean firing rate as function of level)
rate_fun = @(L) max(0, min(R_max, R_max * (L - L_thr) / (L_max - L_thr)));

% Mean firing rates at the two example levels, and the matching Poisson
% means (expected number of spikes per stimulus cycle)
R_low       = rate_fun(L_low);
lambda_low  = R_low / f0;
R_high      = rate_fun(L_high);
lambda_high = R_high / f0;

jitter_sd = 0.001;   % SD of the spike-time jitter within each cycle (s): 1 ms

%% Simulate phase-locked Poisson spike trains for the two levels
% The anonymous function captures f0, nCycles and jitter_sd as they are now.
simulate_spikes = @(lambda_cycle) ...
    simulate_poisson_tone(f0, nCycles, lambda_cycle, jitter_sd);

% Draw order matters for reproducibility: low level first, then high level.
spikeTimes_low  = simulate_spikes(lambda_low);
spikeTimes_high = simulate_spikes(lambda_high);

%% Prepare figure
figure(1);
clf;

%% Panels A and B: example spike trains at the two levels
% One row per panel: subplot position, sine amplitude, spike times,
% panel letter, panel title.
panelSpecs = { ...
    221, 1, spikeTimes_low,  'A', 'low level (just above threshold)'; ...
    223, 2, spikeTimes_high, 'B', 'higher level'};

for iPanel = 1:size(panelSpecs, 1)
    [spIdx, A, spk, letter, ttl] = panelSpecs{iPanel, :};

    subplot(spIdx);
    % Stimulus waveform, with the spike raster drawn above it
    y = A*sin(w*t);
    plot(t, y, 'k-', 'LineWidth', 2, 'Color', col);
    hold on;
    plot_spike_train(spk, col);
    ylabel('Spikes');
    xlim([0 T]);

    if hasNicegraph
        nicegraph;
        axis normal;
        axis off;
        set(gca, 'FontSize', 24);
    else
        % Minimal built-in equivalent
        box off;
        axis off;
        set(gca, 'TickDir', 'out', 'FontSize', 24);
    end

    if hasBftext
        bf_text(0.05, 0.95, letter, 'FontSize', 24);
    end
    title(ttl, 'FontSize', 16);
end

%% Panel C: rate–level function with and without noise
subplot(122); hold on;

L_vec = 0:1:L_max;

% Mean rate with noise: smooth linear (clipped) rate–level function
R_with_noise = rate_fun(L_vec);

% Rate without noise: deterministic firing can only take discrete values,
% so the smooth rate is rounded down to a multiple of rateStep.
% rateStep = 50 spikes/s equals f0/2 for the 100 Hz tone, i.e. an integer
% number of spikes per two stimulus cycles.
% The + eps presumably guards against round-off just below a step boundary.
rateStep = 50;                                                 % spikes/s per step
spikes_per_cycle_no_noise = floor(R_with_noise / rateStep + eps);  % step index
R_no_noise = spikes_per_cycle_no_noise * rateStep;             % staircase rate

% Plot staircase (no noise)
h(1) = stairs(L_vec, R_no_noise, 'LineWidth', 1.5, 'Color', [0.2 0.2 0.2]);
% Plot smooth line (with noise)
h(2) = plot(L_vec, R_with_noise, 'LineWidth', 2, 'LineStyle','-', 'Color', col);

xlabel('sound level (arbitrary units)');
ylabel('firing rate (spikes/s)');

% Style the axes once, after the labels exist, so FontSize 24 is the final
% setting in both branches.
if hasNicegraph
	nicegraph;
	set(gca, 'FontSize', 24);
else
	% Minimal built-in equivalent
	box off;
	axis square;
	set(gca, 'TickDir', 'out', 'FontSize', 24);
end

% Mark the two example levels shown in panels A and B on the noisy curve
plot(L_low,R_low,'o','LineWidth',2,'Color',col,'MarkerFaceColor',col,'MarkerSize',12);
text(L_low-3,R_low+6,'A','FontSize',24,'HorizontalAlignment','center');

plot(L_high,R_high,'o','LineWidth',2,'Color',col,'MarkerFaceColor',col,'MarkerSize',12);
text(L_high-3,R_high+6,'B','FontSize',24,'HorizontalAlignment','center');

title('rate–level function with and without noise','FontSize',16);
legend(h,{'without noise (staircase)', 'with noise (smooth)'}, ...
       'Location','SE');
if hasBftext
	bf_text(0.05,0.95,'C','FontSize',24);
end


%% Save figure
% Writes the current figure to ../images/spike_noise_rate.png (relative to
% the current folder), creating the output folder if necessary.
fname = fullfile('..','images','spike_noise_rate.png');

outDir = fileparts(fname);
if exist(outDir, 'dir') ~= 7
    mkdir(outDir);
end

if hasSavegraph
    savegraph(fname,'png');
elseif ~isempty(which('exportgraphics'))
    % exportgraphics is only present in newer MATLAB releases
    exportgraphics(gcf, fname, 'Resolution', 300);
else
    print(gcf, fname, '-dpng', '-r300');
end

%%  Local functions 

function spikeTimes = simulate_poisson_tone(f0, nCycles, lambda_cycle, jitter_sd, peak_phase)
%SIMULATE_POISSON_TONE Generate spike times for a tone-locked Poisson process.
%
%   spikeTimes = SIMULATE_POISSON_TONE(f0, nCycles, lambda_cycle, jitter_sd)
%   draws, for each of nCycles stimulus cycles, a Poisson-distributed number
%   of spikes with mean lambda_cycle. Each spike is placed at the preferred
%   phase of its cycle plus Gaussian jitter. Spikes jittered outside their
%   own cycle are discarded.
%
%   spikeTimes = SIMULATE_POISSON_TONE(..., peak_phase) sets the preferred
%   firing phase as a fraction of the cycle. The default is 0.25, which is
%   the positive peak of sin(2*pi*f0*t). Pass 0.5 for the cycle midpoint.
%
%   f0           : tone frequency (Hz)
%   nCycles      : number of stimulus cycles
%   lambda_cycle : mean number of spikes per cycle (Poisson)
%   jitter_sd    : standard deviation of temporal jitter (s)
%   peak_phase   : (optional) preferred firing phase, fraction of a cycle
%
%   spikeTimes   : sorted column vector of spike times (s), or [] when no
%                  spikes are produced (including lambda_cycle <= 0)

    if nargin < 5 || isempty(peak_phase)
        peak_phase = 0.25;   % positive peak of sin(2*pi*f0*t)
    end

    spikeTimes = [];
    if lambda_cycle <= 0
        return;
    end

    perCycle = cell(nCycles, 1);
    for k = 1:nCycles
        t_start  = (k - 1) / f0;
        t_end    = k / f0;
        t_center = (k - 1 + peak_phase) / f0;   % preferred firing time

        nSpikes = draw_poisson(lambda_cycle);
        if nSpikes > 0
            t_spk = t_center + jitter_sd * randn(nSpikes, 1);
            % Keep only spikes that stay inside their own cycle
            perCycle{k} = t_spk(t_spk >= t_start & t_spk <= t_end);
        end
    end

    spikeTimes = sort(vertcat(perCycle{:}));
end

function n = draw_poisson(lambda)
%DRAW_POISSON Draw one Poisson(lambda) sample.
%   Uses poissrnd (Statistics Toolbox) when it is on the path. Otherwise it
%   uses Knuth's multiplication method, which is adequate for the small
%   per-cycle means used in this script.
    persistent hasPoissrnd
    if isempty(hasPoissrnd)
        hasPoissrnd = ~isempty(which('poissrnd'));
    end
    if hasPoissrnd
        n = poissrnd(lambda);
        return;
    end
    limit = exp(-lambda);
    n = 0;
    p = rand;
    while p > limit
        n = n + 1;
        p = p * rand;
    end
end

function plot_spike_train(spikeTimes,col)
%PLOT_SPIKE_TRAIN Simple raster-style plotting of a 1D spike train.
%
%   PLOT_SPIKE_TRAIN(spikeTimes, col) draws, into the current axes, one
%   vertical tick from y = 2.5 to y = 5.5 at every spike time, using colour
%   col and line width 2. It also fixes the y-limits to [-2.1 7], leaving
%   room below the ticks for the stimulus waveform. The limits are set even
%   when spikeTimes is empty, so every panel gets the same vertical layout.

    y_bottom = 2.5;
    y_top    = 5.5;

    ylim([-2.1 7]);

    if isempty(spikeTimes)
        return;
    end

    % Draw the whole train as one line object: each tick is a
    % (bottom, top, NaN) triple, and the NaN breaks the polyline between
    % ticks.
    nSpk = numel(spikeTimes);
    xs = [spikeTimes(:)'; spikeTimes(:)'; nan(1, nSpk)];
    ys = repmat([y_bottom; y_top; NaN], 1, nSpk);
    line(xs(:), ys(:), 'Color', col, 'LineWidth', 2);
end