function data = visualsearchexp(varargin)
% VISUALSEARCHEXP_IMPROVED  Visual search (pop-out vs conjunction) RT practical.
%
% This function runs a visual search reaction-time experiment with:
%   - Search type:   pop-out (feature) vs conjunction
%   - Set size:      e.g., [4 8 12 16]
%   - Target:        present vs absent
%
% Participants respond:
%   - Press '1'  -> target PRESENT
%   - Press '0'  -> target ABSENT
%   - Press 'ESC' to abort
%
% Typical stimuli (classic Treisman-style):
%   Pop-out:      target = red 'O', distractors = blue 'X'
%   Conjunction:  target = red 'O', distractors = red 'X' + blue 'O'
%
% Data are saved as a tidy CSV (one row per trial) plus a MAT file.
%
% USAGE
%   data = visualsearchexp_improved();
%   data = visualsearchexp_improved('subject', 1, 'nPerCell', 20);
%   data = visualsearchexp_improved('setSizes', [4 8 12 16], 'fullscreen', true);
%
% NOTES
%   - Does NOT require Psychtoolbox; uses base MATLAB figure callbacks.
%   - For reliable timing, close other heavy apps and avoid remote desktops.
%   - If you have PTB, you can port the drawing to Screen() later.
%
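% Marc-style practical suggestion:
%   20 trials per cell (setSize x searchType x targetPresence = 16 cells with
%   the default set sizes) -> 320 trials. Trials are counted whether or not the
%   response is correct. The default, nPerCell = 10, gives 160 trials for a
%   quicker lab.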
%
% (c) 2025. Provided as teaching code.

%% ------------------------- Parameters -------------------------
% All options are name/value pairs. Counts that are later used as array sizes
% (setSizes, nPerCell) must be whole numbers, otherwise nan/randperm/repmat
% would fail in the middle of a session.
isWhole = @(x) all(isfinite(x(:))) && all(x(:) == round(x(:)));

P = inputParser;
P.addParameter('subject', 1, @(x) isnumeric(x) && isscalar(x));
P.addParameter('setSizes', [4 8 12 16], @(x) isnumeric(x) && ~isempty(x) && all(x(:)>1) && isWhole(x));
P.addParameter('nPerCell', 10, @(x) isnumeric(x) && isscalar(x) && x>=1 && isWhole(x));
P.addParameter('maxRT', 2.0, @(x) isnumeric(x) && isscalar(x) && x>0);
P.addParameter('iti', [0.3 0.6], @(x) isnumeric(x) && numel(x)==2 && x(1)>=0 && x(2)>=x(1));
P.addParameter('fixDur', 0.4, @(x) isnumeric(x) && isscalar(x) && x>=0);
P.addParameter('stimDur', Inf, @(x) isnumeric(x) && isscalar(x) && x>0); % Inf = until response
P.addParameter('minDist', 0.06, @(x) isnumeric(x) && isscalar(x) && x>0); % in normalized units
P.addParameter('margin', 0.08, @(x) isnumeric(x) && isscalar(x) && x>=0 && x<0.4);
P.addParameter('fontSize', 28, @(x) isnumeric(x) && isscalar(x) && x>=8);
P.addParameter('fullscreen', true, @(x) islogical(x) || isnumeric(x));
P.addParameter('practiceN', 8, @(x) isnumeric(x) && isscalar(x) && x>=0);
P.addParameter('seed', [], @(x) isempty(x) || (isnumeric(x) && isscalar(x)));
P.addParameter('dataDir', fullfile('..','..','data'), @(x) ischar(x) || isstring(x)); % for Marc
% P.addParameter('dataDir', fullfile('..','data'), @(x) ischar(x) || isstring(x));
P.parse(varargin{:});
prm = P.Results;
prm.fullscreen = logical(prm.fullscreen);
prm.dataDir = char(prm.dataDir);   % plain char so path concatenation and save() behave uniformly

% Ensure even set sizes for clean splitting in conjunction distractors.
prm.setSizes = prm.setSizes(:).';
if any(mod(prm.setSizes,2)~=0)
    warning('Some set sizes are odd. Conjunction distractor split will be approximate.');
end

% RNG: either shuffle and remember the seed that was drawn, or replay a given seed.
if isempty(prm.seed)
    rng('shuffle');
    rngState = rng;            % read the state in two steps (no call-then-dot chaining)
    prm.seed = rngState.Seed;
else
    rng(prm.seed);
end

%% ------------------------- Trial table -------------------------
% Factors (one row of cond = one trial):
%   column 1  searchType:    0 = pop-out, 1 = conjunction
%   column 2  setSize
%   column 3  targetPresent: 0/1
searchTypes = [0 1];
setSizes = prm.setSizes;
targetPres = [0 1];

% Full factorial, repeated nPerCell times
cond = [];
for s = searchTypes
    for n = setSizes
        for tp = targetPres
            cond = [cond; s, n, tp]; %#ok<AGROW>
        end
    end
end
cond = repmat(cond, prm.nPerCell, 1);
% Randomize trial order
cond = cond(randperm(size(cond,1)), :);

nTrials = size(cond,1);

%% ------------------------- Figure setup -------------------------
fig = figure('Color','w', 'MenuBar','none', 'ToolBar','none', 'NumberTitle','off', ...
    'Name', sprintf('Visual Search (Subject %03d)', prm.subject));
ax = axes('Parent',fig);
axis(ax,'off');
axis(ax,'square');
set(ax,'XLim',[0 1],'YLim',[0 1]);

if prm.fullscreen
    set(fig,'Units','normalized','OuterPosition',[0 0 1 1]);
end

% No key callback is installed outside the wait/response helpers, so key
% presses during ITI and fixation are not recorded.
set(fig,'KeyPressFcn',[]);

%% ------------------------- Instructions -------------------------
showCenteredText(ax, {
    'Visual Search Task'
    ''
    'Decide if the target is PRESENT.'
    ''
    "Press '1'  = target PRESENT"
    "Press '0'  = target ABSENT"
    "Press 'ESC' to stop"
    ''
    'Respond as quickly AND accurately as possible.'
    ''
    'Press any key to start practice.'
    }, prm.fontSize);
waitForAnyKey(fig);

%% ------------------------- Practice -------------------------
% Practice reuses the first practiceN rows of the shuffled trial list.
% ESC during practice stops the whole session, as the instructions promise.
sessionAborted = false;
if prm.practiceN > 0
    practiceCond = cond(1:min(prm.practiceN,nTrials),:);
    [~, sessionAborted] = runBlock(practiceCond, true);

    if ~sessionAborted
        showCenteredText(ax, {
            'End of practice.'
            ''
            'Press any key to start the real experiment.'
            }, prm.fontSize);
        waitForAnyKey(fig);
    end
end

%% ------------------------- Experiment -------------------------
if sessionAborted
    % A zero-row block runs no trials; it only yields the empty results
    % table with the standard columns so the save step below still works.
    data = runBlock(zeros(0,3), false);
    aborted = true;
else
    [data, aborted] = runBlock(cond, false);
end

%% ------------------------- Save -------------------------
if exist(prm.dataDir,'dir') ~= 7
    mkdir(prm.dataDir);
end

nowStr = char(datetime('now','Format','yyyy-MM-dd_HH-mm-ss'));
base = sprintf('visualsearch_sub-%03d_%s_seed-%d', prm.subject, nowStr, prm.seed);

csvFile = fullfile(prm.dataDir, [base '.csv']);
matFile = fullfile(prm.dataDir, [base '.mat']);

% Add metadata columns (constant within a session, one value per row)
nRows = height(data);
data.subject = repmat(prm.subject, nRows, 1);
data.seed = repmat(prm.seed, nRows, 1);
data.aborted = repmat(aborted, nRows, 1);

writetable(data, csvFile);
save(matFile, 'data', 'prm');

%% ------------------------- Wrap-up -------------------------
% Only draw the closing screen if the window still exists; the data are
% already on disk at this point.
if isvalid(fig)
    if aborted
        msg = {'Experiment aborted.', '', 'Data saved up to the last completed trial.', '', 'Thank you!', '', 'Press any key to close.'};
    else
        msg = {'Done!', '', 'Thank you for participating.', '', 'Press any key to close.'};
    end
    showCenteredText(ax, msg, prm.fontSize);
    waitForAnyKey(fig);

    if isvalid(fig), close(fig); end
end

%% ------------------------- Nested runner -------------------------
    function [T, abortedFlag] = runBlock(blockCond, isPractice)
        % RUNBLOCK  Run one block of trials and return one table row per trial.
        %
        %   blockCond   n-by-3 matrix, columns = [searchType, setSize, targetPresent]
        %   isPractice  true -> show feedback for 0.5 s after every trial
        %
        %   T            results table; if ESC is pressed it is cut down to the
        %                trials completed before the abort
        %   abortedFlag  true when the block ended because of ESC
        %
        % Uses prm, fig and ax from the enclosing function.
        colNames = {'trial','searchType','setSize','targetPresent','respPresent','correct','rt','respKey','timeout','stimShown'};
        colTypes = {'double','double','double','double','double','double','double','string','double','double'};
        nRows = size(blockCond,1);
        T = table('Size',[nRows 10], ...
            'VariableTypes', colTypes, ...
            'VariableNames', colNames);

        abortedFlag = false;
        for trialIdx = 1:nRows
            sType = blockCond(trialIdx,1);
            nSet  = blockCond(trialIdx,2);
            tPres = blockCond(trialIdx,3);

            % Blank screen for a jittered inter-trial interval
            blankDur = prm.iti(1) + (prm.iti(2)-prm.iti(1))*rand;
            cla(ax);
            drawnow;
            pause(blankDur);

            % Fixation cross (skipped when fixDur is zero)
            if prm.fixDur > 0
                cla(ax);
                drawFixation(ax);
                drawnow;
                pause(prm.fixDur);
            end

            % Build the display, then put it on screen
            [items, stimOK] = makeStimulus(nSet, logical(tPres), logical(sType));
            cla(ax);
            if ~stimOK
                % Fallback message (should be rare)
                showCenteredText(ax, {'Stimulus generation failed.', 'Press any key.'}, prm.fontSize);
            else
                drawStimulus(ax, items, prm.fontSize);
            end
            drawnow;

            % Wait for the participant
            [respKey, rt, timedOut, escPressed] = getResponse(fig, prm.maxRT, prm.stimDur);
            if escPressed
                abortedFlag = true;
                % Drop this row and every row after it: keep completed trials only
                T(trialIdx:end,:) = [];
                return
            end

            % Translate the key into a present/absent judgement (NaN = no valid key)
            switch respKey
                case '0'
                    respPresent = 0;
                case '1'
                    respPresent = 1;
                otherwise
                    respPresent = nan;
            end

            correct = ~timedOut && ~isnan(respPresent) && (respPresent == tPres);

            % Feedback only in practice
            if isPractice
                cla(ax);
                if timedOut
                    feedback = {'Too slow.', 'Respond faster.'};
                elseif isnan(respPresent)
                    feedback = {'Invalid key.', "Use '0' or '1'."};
                elseif correct
                    feedback = {'Correct'};
                else
                    feedback = {'Incorrect'};
                end
                showCenteredText(ax, feedback, prm.fontSize);
                drawnow;
                pause(0.5);
            end

            % Store the whole row at once (same column order as colNames)
            T(trialIdx,:) = {trialIdx, sType, nSet, tPres, respPresent, ...
                double(correct), rt, string(respKey), double(timedOut), double(stimOK)};
        end
    end

%% ------------------------- Helpers -------------------------
    function [items, ok] = makeStimulus(nSet, targetPresent, isConjunction)
        % MAKESTIMULUS  Build one search display.
        %
        %   nSet           number of items in the display
        %   targetPresent  true -> one item becomes the red 'O' target
        %   isConjunction  false -> pop-out display, true -> conjunction display
        %
        %   items  1-by-nSet struct array with fields x, y, ch, col
        %          (empty when ok is false)
        %   ok     false when nSet < 2 or no valid layout was found
        %
        % Reads prm.margin and prm.minDist from the enclosing function.
        items = struct('x',{},'y',{},'ch',{},'col',{});
        ok = ~(nSet < 2);
        if ~ok
            return
        end

        % Non-overlapping positions in [margin, 1-margin]
        [x,y,ok] = samplePositions(nSet, prm.margin, prm.minDist, 2000);
        if ~ok
            return
        end

        blue = [0 0 1];
        red  = [1 0 0];

        if isConjunction
            % CONJUNCTION: distractors = red X + blue O; target = red O.
            % The first floor(nSet/2) entries of a random order are red X,
            % the remaining items are blue O.
            order = randperm(nSet);
            redXIdx = order(1:floor(nSet/2));

            glyphs = repmat({'O'}, 1, nSet);
            cols   = repmat({blue}, 1, nSet);
            glyphs(redXIdx) = {'X'};
            cols(redXIdx)   = {red};

            if targetPresent
                % One random red X turns into the red O target.
                % (nSet >= 2 here, so there is always at least one red X.)
                kTar = redXIdx(randi(numel(redXIdx)));
                glyphs{kTar} = 'O';
                cols{kTar}   = red;
            end
        else
            % POP-OUT: distractors = blue X, target = red O
            glyphs = repmat({'X'}, 1, nSet);
            cols   = repmat({blue}, 1, nSet);

            if targetPresent
                kTar = randi(nSet);
                glyphs{kTar} = 'O';
                cols{kTar}   = red;
            end
        end

        % Assemble the struct array in one go (row-shaped, one element per item)
        items = struct('x', num2cell(x(:).'), 'y', num2cell(y(:).'), ...
            'ch', glyphs, 'col', cols);
    end

    function drawStimulus(axh, items, fsz)
        % DRAWSTIMULUS  Place one centred text object per item on the axes.
        %
        %   axh    target axes
        %   items  struct array with fields x, y, ch (character) and col (RGB)
        %   fsz    font size applied to every item
        centred = {'HorizontalAlignment','center', 'VerticalAlignment','middle'};
        k = 0;
        while k < numel(items)
            k = k + 1;
            it = items(k);
            text(axh, it.x, it.y, it.ch, 'Color', it.col, 'FontSize', fsz, centred{:});
        end
    end

    function drawFixation(axh)
        % DRAWFIXATION  Draw a black '+' at the centre of the axes, 6 points
        % larger than the stimulus font size (prm.fontSize from the parent).
        crossSize = prm.fontSize + 6;
        centred = {'HorizontalAlignment','center', 'VerticalAlignment','middle'};
        text(axh, 0.5, 0.5, '+', 'Color',[0 0 0], 'FontSize', crossSize, centred{:});
    end

    function showCenteredText(axh, lines, fsz)
        % SHOWCENTEREDTEXT  Clear the axes and show lines of black centred text.
        %
        %   axh    target axes
        %   lines  cell array of lines, or a char/string (converted with cellstr)
        %   fsz    font size
        %
        % The first line sits at y = 0.55; each further line is 0.06 lower.
        % The axes are reset to a hidden, square unit box and the screen is flushed.
        cla(axh);
        if ischar(lines) || isstring(lines)
            lines = cellstr(lines);
        end

        topY = 0.55;
        lineStep = 0.06;
        style = {'Color',[0 0 0], 'FontSize', fsz, ...
            'HorizontalAlignment','center', 'VerticalAlignment','middle'};

        row = 1;
        while row <= numel(lines)
            text(axh, 0.5, topY - (row-1)*lineStep, lines{row}, style{:});
            row = row + 1;
        end

        axis(axh,'off');
        axis(axh,'square');
        set(axh,'XLim',[0 1],'YLim',[0 1]);
        drawnow;
    end

    function waitForAnyKey(figHandle)
        % WAITFORANYKEY  Block until any key is pressed in the figure.
        %
        %   figHandle  figure whose key presses are monitored
        %
        % A temporary KeyPressFcn stores the pressed key name in the figure's
        % appdata ('key'); the loop polls that value every 10 ms. The callback
        % is removed again before returning.
        %
        % If the figure has been closed (before the call or while waiting) the
        % function returns quietly instead of throwing from setappdata /
        % getappdata on a deleted handle.
        if ~isvalid(figHandle)
            return
        end

        setappdata(figHandle,'key','');
        set(figHandle,'KeyPressFcn',@(src,evt) setappdata(src,'key',evt.Key));
        while true
            pause(0.01);   % pause lets the callback (or a window close) run
            if ~isvalid(figHandle)
                return     % window closed while waiting: nothing left to clean up
            end
            if ~isempty(getappdata(figHandle,'key'))
                break
            end
        end
        set(figHandle,'KeyPressFcn',[]);
    end

    function [respKey, rt, timedOut, escPressed] = getResponse(figHandle, maxRT, stimDur)
        % GETRESPONSE  Wait for '0', '1' or ESC, for at most maxRT seconds.
        %
        %   figHandle  figure that receives the key presses
        %   maxRT      response deadline in seconds, measured from the call
        %   stimDur    seconds after which the display is cleared (cla on the
        %              parent's ax); Inf keeps it visible until the trial ends
        %
        %   respKey     '0' or '1'; '' when no valid key was given
        %   rt          seconds from call to key press; maxRT on timeout
        %   timedOut    true when the deadline passed without a valid key
        %   escPressed  true when ESC was pressed, or when the figure was
        %               closed (a closed window is reported as an abort
        %               because no further trial can be drawn)
        %
        % Keys other than '0', '1' and ESC are ignored and waiting continues.
        respKey = '';
        rt = 0;
        timedOut = false;
        escPressed = false;

        % A deleted figure cannot take appdata or callbacks: report an abort.
        if ~isvalid(figHandle)
            escPressed = true;
            return
        end

        setappdata(figHandle,'key','');
        setappdata(figHandle,'char','');
        set(figHandle,'KeyPressFcn',@onKey);

        t0 = tic;
        stimHidden = false;
        while true
            pause(0.001);   % lets pending key callbacks (or a window close) run
            t = toc(t0);

            % Must be checked before any other access to the figure or axes.
            if ~isvalid(figHandle)
                escPressed = true;
                rt = t;
                break
            end

            % Optional hide stimulus after stimDur
            if ~stimHidden && isfinite(stimDur) && t >= stimDur
                cla(ax);
                drawnow;
                stimHidden = true;
            end

            k = getappdata(figHandle,'char');
            if ~isempty(k)
                if strcmp(k, char(27)) % ESC
                    escPressed = true;
                    rt = t;
                    break
                end
                if any(strcmp(k, {'0','1'}))
                    respKey = k;
                    rt = t;
                    break
                else
                    % ignore other keys, keep waiting
                    setappdata(figHandle,'char','');
                end
            end

            if t >= maxRT
                timedOut = true;
                rt = maxRT;
                break
            end
        end

        if isvalid(figHandle)
            set(figHandle,'KeyPressFcn',[]);
        end

        function onKey(src,evt)
            % Store the response character ('0', '1' or ESC) in appdata 'char'
            % and the raw key name in appdata 'key'.
            % evt is an event object, so isprop (not isfield) is the right
            % test. The typed character is tried first; the key name is the
            % fallback for platforms where Character is not one of the three.
            ch = '';
            if isprop(evt,'Character') && ischar(evt.Character) && isscalar(evt.Character) ...
                    && any(evt.Character == ['0', '1', char(27)])
                ch = evt.Character;
            elseif any(strcmp(evt.Key, {'0','numpad0'}))
                ch = '0';
            elseif any(strcmp(evt.Key, {'1','numpad1'}))
                ch = '1';
            elseif strcmp(evt.Key,'escape')
                ch = char(27);
            end
            if ~isempty(ch)
                setappdata(src,'char',ch);
            end
            setappdata(src,'key',evt.Key);
        end
    end

    function [x,y,ok] = samplePositions(nPts, margin, minDist, maxIter)
        % SAMPLEPOSITIONS  Rejection-sample nPts points that keep a minimum spacing.
        %
        %   nPts     number of points wanted
        %   margin   candidates are uniform in [margin, 1-margin] on both axes
        %   minDist  a candidate is kept only if it is at least this far from
        %            every point already placed
        %   maxIter  maximum number of candidates drawn in total
        %
        %   x, y  nPts-by-1 coordinates; entries not yet placed stay NaN
        %   ok    false when maxIter candidates were not enough
        ok = true;
        x = nan(nPts,1);
        y = nan(nPts,1);
        span = 1 - 2*margin;

        placed = 0;
        draws = 0;
        while placed < nPts
            draws = draws + 1;
            if draws > maxIter
                ok = false;
                return
            end

            pt = margin + span*rand(1,2);

            % With nothing placed yet the distance list is empty and all(...)
            % is true, so the first candidate is always accepted.
            gaps = hypot(x(1:placed)-pt(1), y(1:placed)-pt(2));
            if all(gaps >= minDist)
                placed = placed + 1;
                x(placed) = pt(1);
                y(placed) = pt(2);
            end
        end
    end

end
