classdef BA
    %BA is a Bayesian ARTMAP object that can be trained and used to classify
    %
    % To construct a BA class use the following constructor:
    %   obj = BA (smax,pmin,type,covarType)
    %       type: 'ba' / 'sba' / 'fix_smax'
    %       covarType: 'full' / 'diagonal' / 'equal'
    %
    %   All arguments are optional and defaults are taken from defaults.mat
    %
    % The following methods are available:
    %
    %   obj = setFeatures (obj, featuresToUse)
    %       Set which features (columns of the pattern matrix) are used.
    %       Can only run before training.
    %
    %   obj = train (obj, patterns, labels)
    %       Train model on a list of patterns and corresponding labels.
    %
    %   classPosterior = val (obj, patterns)
    %       Evaluate class-posterior-probabilities for each class (columns)
    %       and each pattern (rows).
    %
    %   classification = classify (obj, patterns)
    %       Classify each of a list of patterns.
    %
    %   inference = test (obj, patterns, labels)
    %       Evaluate inference by using the test / validation patterns
    %       provided.

    % Code written by Noam Nelke
    % Under the supervision of Dr. Boaz Lerner
    % Based in part on the work of Boaz Vigdor and Saar Abramowitz

    properties (GetAccess = 'public', SetAccess = 'private')
        type;               % 'ba' / 'sba' / 'fix_smax' (stored lower-case)
        covarType;          % 'full' / 'diagonal' / 'equal' (stored lower-case)
        nclusters = 0;      % number of clusters created so far
        means = [];         % one row per cluster, one column per used feature
        priors = [];        % row sums of map divided by the total of map
        covars = [];        % nfeatures x nfeatures x nclusters
        map = [];           % rows are clusters, columns are class labels
        smax;               % upper bound on det() of an updated cluster covariance
        pmin;               % minimum fraction of a cluster's map row on the label
        nfeatures;          % number of features actually used
        featuresToUse;      % column indices selected from the raw patterns
        initCovar;          % covariance assigned to every newly created cluster
        ninputs;            % number of raw pattern columns seen at first training
    end

    methods
        function obj = BA (smax, pmin, type, covarType)
            % obj = BA (smax, pmin, type, covarType)
            %   Construct a model. Any trailing arguments that are omitted
            %   are read from defaults.mat (only loaded when fewer than four
            %   arguments are supplied).
            if nargin < 4
                default = load('defaults.mat');
                obj.smax = default.smax;
                obj.pmin = default.pmin;
                obj.type = lower(default.type);
                % Lower-cased like type, so the switch statements match.
                obj.covarType = lower(default.covarType);
            end
            if nargin >= 1
                obj.smax = smax;
            end
            if nargin >= 2
                obj.pmin = pmin;
            end
            if nargin >= 3
                obj.type = lower(type);
            end
            if nargin == 4
                obj.covarType = lower(covarType);
            end
            % Fail early: an unknown value would otherwise fall through the
            % switch statements used during training.
            if ~ischar(obj.type) || ~any(strcmp(obj.type, {'ba','sba','fix_smax'}))
                error('BA type must be one of ''ba'', ''sba'' or ''fix_smax''.');
            end
            if ~ischar(obj.covarType) || ~any(strcmp(obj.covarType, {'full','diagonal','equal'}))
                error('BA covarType must be one of ''full'', ''diagonal'' or ''equal''.');
            end
        end

        function obj = setFeatures (obj, featuresToUse)
            % obj = setFeatures (obj, featuresToUse)
            %   Set which features are used. Can only run before training.
            %   featuresToUse is used as a column index into the patterns.
            if obj.nclusters == 0
                obj.featuresToUse = featuresToUse;
            else
                error('Cannot change feature space after training has begun!');
            end
        end

        function obj = train (obj, patterns, labels)
            % obj = train (obj, patterns, labels)
            %   Train model on a list of patterns and corresponding labels.
            %   patterns: one pattern per row (all raw columns, the feature
            %             selection is applied here).
            %   labels:   positive integer class label for each pattern,
            %             read by linear index (labels(k) belongs to row k).
            %   Can be called repeatedly to continue training.
            patterns = obj.fixPatterns(patterns);
            if isempty(patterns)
                return;     % nothing to learn
            end
            if numel(labels) < size(patterns,1)
                error('A label must be provided for every pattern.');
            end
            first = 1;
            if obj.nclusters == 0
                % The first pattern fixes the feature space and seeds the
                % first cluster.
                obj = obj.firstPattern(patterns(1,:), labels(1));
                first = 2;
            end
            patterns = patterns(:,obj.featuresToUse);
            for k = first:size(patterns,1)
                obj = obj.learnPattern(patterns(k,:), labels(k));
            end
        end

        function classPosterior = val (obj, patterns)
            % classPosterior = val (obj, patterns)
            %   Evaluate class-posterior-probabilities for each class (columns)
            %   and each pattern (rows).
            patterns = obj.fixPatterns(patterns);
            patterns = patterns(:,obj.featuresToUse);
            clusterPosterior = nan(size(patterns,1),obj.nclusters);
            for cluster = 1:obj.nclusters
                clusterPosterior(:,cluster) = BA.gaussDensity(obj.means(cluster,:),obj.covars(:,:,cluster),patterns);
            end
            % Weight the per-cluster densities by the cluster/class map and
            % normalize each row to sum to one.
            classPosterior = clusterPosterior*obj.map;
            classPosterior = classPosterior./(sum(classPosterior,2)*ones(1,size(classPosterior,2)));
        end

        function [classification, classPosterior] = classify (obj, patterns)
            % classification = classify (obj, patterns)
            %   Classify each of a list of patterns: returns, per row, the
            %   column index of the largest class posterior. The posteriors
            %   themselves are available as a second output.
            classPosterior = obj.val(patterns);
            [junk, classification] = max(classPosterior,[],2); %#ok<ASGLU>
        end

        function [inference, classification, classPosterior] = test (obj, patterns, labels)
            % inference = test (obj, patterns, labels)
            %   Evaluate inference by using the test / validation patterns
            %   provided. inference is the fraction of patterns whose
            %   classification equals the given label.
            %   labels may be a row or column vector; for a matrix the
            %   first column is used.
            [classification, classPosterior] = obj.classify(patterns);
            if isvector(labels)
                labels = labels(:);     % accept row vectors as train does
            else
                labels = labels(:,1);
            end
            inference = sum(classification == labels) / numel(classification);
        end
    end

    methods (Static)
        function density = gaussDensity (mu, covar, patterns)
            % density = gaussDensity (mu, covar, patterns)
            %   Returns the density of a given Gaussian at several points
            %   mu:       1 x n mean
            %   covar:    n x n covariance
            %   patterns: m x n points, one per row
            %   density:  m x 1
            [m, n] = size(patterns);
            patterns = patterns - ones(m, 1)*mu;
            fact = sum(((patterns/covar).*patterns), 2);
            density = exp(-0.5*fact)./sqrt((2*pi)^n*det(covar));
        end
    end

    methods (Access = 'private')
        function obj = learnPattern (obj, pattern, label)
            % Learn a single (already feature-selected) pattern: update the
            % best matching cluster that satisfies both the covariance bound
            % and the pmin map test, or create a new cluster if none does.
            patternsPerCluster = sum(obj.map,2);
            logPosterior = nan(1,obj.nclusters);
            for cluster = 1:obj.nclusters
                d = pattern - obj.means(cluster,:);
                s = obj.covars(:,:,cluster);
                logPosterior(cluster) = log(obj.priors(cluster)) + log(det(s)^(-.5)) + (-.5*d*s^(-1)*d');
            end
            candidateClusters = 1:obj.nclusters;
            currentSmax = obj.smax;
            artmapComplete = false;
            while ~artmapComplete
                % Pick the most probable remaining candidate whose updated
                % covariance determinant stays within currentSmax.
                clusterChosen = false;
                while ~clusterChosen
                    [maxLogPosterior, bestIndex] = max(logPosterior(candidateClusters)); %#ok<ASGLU>
                    bestCluster = candidateClusters(bestIndex);
                    newMeans = (patternsPerCluster(bestCluster)*obj.means(bestCluster,:)+pattern)/(patternsPerCluster(bestCluster)+1);
                    switch obj.covarType
                        case 'full'
                            newCovars = (patternsPerCluster(bestCluster)*obj.covars(:,:,bestCluster)+(pattern-newMeans)'*(pattern-newMeans))/(patternsPerCluster(bestCluster)+1);
                        case 'diagonal'
                            newCovars = (patternsPerCluster(bestCluster)*obj.covars(:,:,bestCluster)+(pattern-newMeans)'*(pattern-newMeans).*eye(obj.nfeatures))/(patternsPerCluster(bestCluster)+1);
                        case 'equal'
                            newCovars = (patternsPerCluster(bestCluster)*obj.covars(:,:,bestCluster)+mean((pattern-newMeans).^2).*eye(obj.nfeatures))/(patternsPerCluster(bestCluster)+1);
                    end
                    if (det(newCovars)<=currentSmax)
                        clusterChosen = true;
                    else
                        candidateClusters(bestIndex) = [];
                        if isempty(candidateClusters)
                            % No candidate fits: start a new cluster.
                            obj = obj.addCluster(pattern, label);
                            artmapComplete = true;
                            break;
                        end
                    end
                end
                if artmapComplete
                    break;
                end
                % Tentatively update the map and check the pmin criterion.
                newMap = obj.map;
                if label>size(newMap,2)
                    newMap(bestCluster,label) = 0;  % grow map for an unseen label
                end
                switch obj.type
                    case {'ba','fix_smax'}
                        newMap(bestCluster,label) = newMap(bestCluster,label)+1;
                    case 'sba'
                        newMap(bestCluster,:) = newMap(bestCluster,:) + obj.mapDelta(newMap,newMeans,newCovars,pattern,bestCluster);
                end
                if newMap(bestCluster,label)/sum(newMap(bestCluster,:)) >= obj.pmin
                    % Accepted: commit the tentative update.
                    obj.means(bestCluster,:) = newMeans;
                    obj.covars(:,:,bestCluster) = newCovars;
                    obj.map = newMap;
                    break;
                else
                    % Rejected: for 'ba' and 'sba' tighten the bound to just
                    % below the rejected determinant, then try the rest.
                    if any(strcmp(obj.type,{'ba' 'sba'}))
                        currentSmax = det(newCovars)-eps(det(newCovars));
                    end
                    candidateClusters(bestIndex) = [];
                    if isempty(candidateClusters)
                        obj = obj.addCluster(pattern, label);
                        break;
                    end
                end
            end
            obj.priors = sum(obj.map,2) / sum(sum(obj.map));
        end

        function obj = firstPattern (obj, pattern, label)
            % Initialise the feature space from the first raw training
            % pattern (all columns) and create the first cluster from it.
            obj.ninputs = length(pattern);
            if isempty(obj.featuresToUse)
                obj.featuresToUse = 1:obj.ninputs;
            else
                obj.featuresToUse(obj.featuresToUse > obj.ninputs) = []; % Remove selected features that don't exist
            end
            obj.nfeatures = length(obj.featuresToUse);
            % Reduce the pattern to the selected features so that the cluster
            % mean has the same width as the covariance matrices.
            pattern = pattern(1,obj.featuresToUse);
            obj.initCovar = eye(obj.nfeatures) .* obj.smax/100;
            obj.covars = nan(obj.nfeatures,obj.nfeatures,0);
            obj = obj.addCluster(pattern, label);
            obj.priors = sum(obj.map,2) / sum(sum(obj.map));
        end

        function obj = addCluster (obj, pattern, label)
            % Append a cluster centred on pattern with covariance initCovar
            % and initialise its row of the cluster/class map.
            obj.nclusters = obj.nclusters + 1;
            obj.means(obj.nclusters,:) = pattern;
            obj.covars(:,:,obj.nclusters) = obj.initCovar;
            switch obj.type
                case {'ba','fix_smax'}
                    obj.map(obj.nclusters, label) = 1;
                case 'sba'
                    % The map row has one entry per class (not per feature).
                    nclasses = max(size(obj.map,2), label);
                    obj.map(obj.nclusters, 1:nclasses) = (1-obj.pmin)/2;
                    obj.map(obj.nclusters, label) = (1+obj.pmin)/2;
            end
        end

        function posterior = mapDelta (obj,newMap,inputMeans,inputCovars,pattern,bestCluster)
            % Return a 1 x nclasses vector (normalized to sum to one) that
            % learnPattern adds to the chosen cluster's map row for 'sba'.
            % It is evaluated with bestCluster's mean and covariance
            % replaced by the tentative inputMeans / inputCovars.
            newMeans = obj.means;
            newMeans(bestCluster,:) = inputMeans;
            newCovars = obj.covars;
            newCovars(:,:,bestCluster) = inputCovars;
            Px_given_W = nan(obj.nclusters,1);
            for cluster = 1:obj.nclusters
                mu = newMeans(cluster,:);
                sigma = newCovars(:,:,cluster);
                % Log of the Gaussian density of pattern under this cluster.
                Px_given_W(cluster) = log((det(sigma)^(-.5))*((2*pi)^(-obj.nfeatures/2)))-.5*(pattern-mu)*(sigma^(-1))*(pattern-mu)';
            end
            % Scale each density by the total map weight of its cluster.
            Px_given_W = exp(Px_given_W + log(sum(newMap,2)));
            num = nan(1,size(newMap,2));
            for class = 1:size(newMap,2)
                num(class) = newMap(:,class)'*Px_given_W;
            end
            denum = sum(num);
            if denum == 0
                % All terms underflowed: recompute in the log domain,
                % relative to the first class.
                for class = 1:size(newMap,2)
                    num(class) = sum(log(newMap(:,class))+log(Px_given_W));
                end
                num = exp(num-num(1));
                denum = sum(num);
            end
            posterior = num/denum;
        end

        function patterns = fixPatterns (obj,patterns)
            % Flatten N-D input to one row per pattern and, once the model
            % has been trained, check that the raw column count is unchanged.
            if ndims(patterns) > 2
                patterns = reshape(patterns,size(patterns,1),[]);
            end
            if obj.nclusters ~= 0
                % Compare against the raw width (feature selection happens
                % after this call). Objects saved before ninputs existed
                % fall back to nfeatures.
                expected = obj.ninputs;
                if isempty(expected)
                    expected = obj.nfeatures;
                end
                if size(patterns,2) ~= expected
                    error('The number of features must remain constant.\nYou provided %g instead of %g.',size(patterns,2),expected);
                end
            end
        end
    end
end