@stdlib/ml-incr-binary-classification
Version:
Incrementally perform binary classification using stochastic gradient descent (SGD).
20 lines (19 loc) • 9.33 kB
JavaScript
'use strict';

// MODULES //

var setReadOnly = require( '@stdlib/utils-define-nonenumerable-read-only-property' );
var setReadOnlyAccessor = require( '@stdlib/utils-define-nonenumerable-read-only-accessor' );
var isPositiveInteger = require( '@stdlib/assert-is-positive-integer' ).isPrimitive;
var isNonNegativeNumber = require( '@stdlib/assert-is-nonnegative-number' ).isPrimitive;
var isPositiveNumber = require( '@stdlib/assert-is-positive-number' ).isPrimitive;
var isNumber = require( '@stdlib/assert-is-number' ).isPrimitive;
var isBoolean = require( '@stdlib/assert-is-boolean' ).isPrimitive;
var isArrayLikeObject = require( '@stdlib/assert-is-array-like-object' );
var isPlainObject = require( '@stdlib/assert-is-plain-object' );
var isVectorLike = require( '@stdlib/assert-is-vector-like' );
var isndarrayLike = require( '@stdlib/assert-is-ndarray-like' );
var hasOwnProp = require( '@stdlib/assert-has-own-property' );
var contains = require( '@stdlib/assert-contains' );
var format = require( '@stdlib/string-format' );
var gdot = require( '@stdlib/blas-base-gdot' ).ndarray;
var gaxpy = require( '@stdlib/blas-base-gaxpy' ).ndarray;
var dcopy = require( '@stdlib/blas-base-dcopy' );
var dscal = require( '@stdlib/blas-base-dscal' );
var max = require( '@stdlib/math-base-special-max' );
var exp = require( '@stdlib/math-base-special-exp' );
var pow = require( '@stdlib/math-base-special-pow' );
var expit = require( '@stdlib/math-base-special-expit' );
var Float64Array = require( '@stdlib/array-float64' );
var ndarray = require( '@stdlib/ndarray-ctor' );
var shape2strides = require( '@stdlib/ndarray-base-shape2strides' );
var numel = require( '@stdlib/ndarray-base-numel' );
var vind2bind = require( '@stdlib/ndarray-base-vind2bind' );


// VARIABLES //

// Lower bound applied to the per-step L2 shrink factor `1 - eta*lambda`, so that `_scale` always receives a strictly positive value:
var MIN_SHRINK_FACTOR = 1.0e-7;

// When the running scale factor falls below this threshold, it is folded into the stored weights and reset to one (presumably to avoid precision loss/underflow in the scale factor):
var MIN_SCALE_FACTOR = 1.0e-11;

// Mapping from learning rate schedule names to `Model` method names:
var LEARNING_RATE_METHODS = {
	'basic': '_basicLearningRate',
	'constant': '_constantLearningRate',
	'invscaling': '_inverseScalingLearningRate',
	'pegasos': '_pegasosLearningRate'
};

// Mapping from loss function names to `Model` method names:
var LOSS_METHODS = {
	'hinge': '_hingeLoss',
	'log': '_logLoss',
	'modifiedHuber': '_modifiedHuberLoss',
	'perceptron': '_perceptronLoss',
	'squaredHinge': '_squaredHingeLoss'
};

// Default `learningRate` option values for each schedule (`[ name, eta0?, power? ]`):
var LEARNING_RATE_DEFAULTS = {
	'basic': [ 'basic' ],
	'constant': [ 'constant', 0.02 ],
	'invscaling': [ 'invscaling', 0.02, 0.5 ],
	'pegasos': [ 'pegasos' ]
};

// Supported learning rate schedules:
var LEARNING_RATES = [ 'basic', 'constant', 'invscaling', 'pegasos' ];

// Supported loss functions:
var LOSS_FUNCTIONS = [ 'hinge', 'log', 'modifiedHuber', 'perceptron', 'squaredHinge' ];


// MODEL //

/**
* Binary classification model trained via stochastic gradient descent.
*
* ## Notes
*
* -   Weights are stored as `scaleFactor * weights`. L2 regularization only shrinks `scaleFactor`, which avoids touching every weight on each update.
* -   When `opts.intercept` is `true`, the intercept is stored as the last element of `weights` and is scaled by `scaleFactor` just like the feature weights.
*
* @private
* @constructor
* @param {PositiveInteger} N - number of features
* @param {Object} opts - validated model options
* @param {boolean} opts.intercept - whether to fit an intercept
* @param {NonNegativeNumber} opts.lambda - regularization parameter
* @param {ArrayLikeObject} opts.learningRate - learning rate schedule and parameters
* @param {string} opts.loss - loss function name
* @returns {Model} model instance
*/
function Model( N, opts ) {
	var len = ( opts.intercept ) ? N+1 : N;

	this._N = N;
	this._opts = opts;
	this._scaleFactor = 1.0;
	this._t = 0;
	this._learningRateMethod = LEARNING_RATE_METHODS[ opts.learningRate[ 0 ] ];
	this._lossMethod = LOSS_METHODS[ opts.loss ];
	this._weights = new Float64Array( len );
	this._coefficients = new ndarray( 'float64', new Float64Array( len ), [ len ], [ 1 ], 0, 'row-major' );
	return this;
}

/**
* Adds `alpha * x` to the (scaled) weight vector. When an intercept is fitted, `alpha` is also added to the intercept (feature value `1`).
*
* @private
* @name _add
* @memberof Model.prototype
* @param {VectorLike} x - one-dimensional ndarray of features
* @param {number} alpha - scalar multiplier
* @returns {Model} model instance
*/
setReadOnly( Model.prototype, '_add', function add( x, alpha ) {
	var w = this._weights;

	// Weights are stored unscaled, so divide by the current scale factor:
	var a = alpha / this._scaleFactor;

	gaxpy( x.shape[ 0 ], a, x.data, x.strides[ 0 ], x.offset, w, 1, 0 );
	if ( this._opts.intercept ) {
		w[ this._N ] += a;
	}
	return this;
});

/**
* Returns the "basic" learning rate `10 / (10 + t)`.
*
* @private
* @name _basicLearningRate
* @memberof Model.prototype
* @returns {number} learning rate
*/
setReadOnly( Model.prototype, '_basicLearningRate', function basicLearningRate() {
	return 10.0 / ( 10.0 + this._t );
});

/**
* Returns the constant learning rate `learningRate[1]`.
*
* @private
* @name _constantLearningRate
* @memberof Model.prototype
* @returns {number} learning rate
*/
setReadOnly( Model.prototype, '_constantLearningRate', function constantLearningRate() {
	return this._opts.learningRate[ 1 ];
});

/**
* Computes the model's linear score `scaleFactor * (w·x + b)` for a strided feature vector.
*
* @private
* @name _dot
* @memberof Model.prototype
* @param {Collection} buf - data buffer containing the features
* @param {integer} stride - stride between consecutive features
* @param {NonNegativeInteger} offset - buffer index of the first feature
* @returns {number} linear score
*/
setReadOnly( Model.prototype, '_dot', function dot( buf, stride, offset ) {
	var w = this._weights;
	var v = gdot( this._N, w, 1, 0, buf, stride, offset );
	if ( this._opts.intercept ) {
		v += w[ this._N ];
	}
	return v * this._scaleFactor;
});

/**
* Updates the model using the hinge loss.
*
* @private
* @name _hingeLoss
* @memberof Model.prototype
* @param {VectorLike} x - one-dimensional ndarray of features
* @param {number} y - response value (`-1` or `+1`)
* @returns {Model} model instance
*/
setReadOnly( Model.prototype, '_hingeLoss', function hingeLoss( x, y ) {
	var eta;
	var d;

	eta = this[ this._learningRateMethod ]();
	this._regularize( eta );
	d = this._dot( x.data, x.strides[ 0 ], x.offset );

	// Only samples inside the margin contribute a (sub)gradient:
	if ( y*d < 1.0 ) {
		this._add( x, y*eta );
	}
	return this;
});

/**
* Returns the inverse-scaling learning rate `learningRate[1] / t^learningRate[2]`.
*
* @private
* @name _inverseScalingLearningRate
* @memberof Model.prototype
* @returns {number} learning rate
*/
setReadOnly( Model.prototype, '_inverseScalingLearningRate', function inverseScalingLearningRate() {
	var lr = this._opts.learningRate;
	return lr[ 1 ] / pow( this._t, lr[ 2 ] );
});

/**
* Updates the model using the logistic (log) loss.
*
* @private
* @name _logLoss
* @memberof Model.prototype
* @param {VectorLike} x - one-dimensional ndarray of features
* @param {number} y - response value (`-1` or `+1`)
* @returns {Model} model instance
*/
setReadOnly( Model.prototype, '_logLoss', function logLoss( x, y ) {
	var eta;
	var g;
	var d;

	eta = this[ this._learningRateMethod ]();
	this._regularize( eta );
	d = this._dot( x.data, x.strides[ 0 ], x.offset );

	// Negative derivative of `log(1 + exp(-y*d))` with respect to `d`:
	g = y / ( 1.0 + exp( y*d ) );
	this._add( x, eta*g );
	return this;
});

/**
* Updates the model using the modified Huber loss.
*
* ## Notes
*
* -   For margin `d = y*f`, the loss is `-4d` when `d < -1`, `(1-d)^2` when `-1 <= d < 1`, and `0` otherwise. Hence no update is applied when `d >= 1`.
* -   As in `_squaredHingeLoss`, the constant factor `2` of the quadratic branch is absorbed into the learning rate.
*
* @private
* @name _modifiedHuberLoss
* @memberof Model.prototype
* @param {VectorLike} x - one-dimensional ndarray of features
* @param {number} y - response value (`-1` or `+1`)
* @returns {Model} model instance
*/
setReadOnly( Model.prototype, '_modifiedHuberLoss', function modifiedHuberLoss( x, y ) {
	var eta;
	var d;

	eta = this[ this._learningRateMethod ]();
	this._regularize( eta );
	d = y * this._dot( x.data, x.strides[ 0 ], x.offset );
	if ( d < -1.0 ) {
		this._add( x, 4.0*eta*y );
	} else if ( d < 1.0 ) {
		this._add( x, eta*( y-(d*y) ) );
	}
	return this;
});

/**
* Returns the Pegasos learning rate `1 / (lambda * t)`.
*
* ## Notes
*
* -   NOTE(review): `lambda` is only validated as nonnegative. With `lambda = 0` this returns `Infinity` (and `_regularize` is skipped), which will drive the weights to non-finite values. Consider rejecting that combination.
*
* @private
* @name _pegasosLearningRate
* @memberof Model.prototype
* @returns {number} learning rate
*/
setReadOnly( Model.prototype, '_pegasosLearningRate', function pegasosLearningRate() {
	return 1.0 / ( this._opts.lambda * this._t );
});

/**
* Updates the model using the perceptron loss.
*
* @private
* @name _perceptronLoss
* @memberof Model.prototype
* @param {VectorLike} x - one-dimensional ndarray of features
* @param {number} y - response value (`-1` or `+1`)
* @returns {Model} model instance
*/
setReadOnly( Model.prototype, '_perceptronLoss', function perceptronLoss( x, y ) {
	var eta;
	var d;

	eta = this[ this._learningRateMethod ]();
	this._regularize( eta );
	d = this._dot( x.data, x.strides[ 0 ], x.offset );

	// Only misclassified (or on-boundary) samples trigger an update:
	if ( y*d <= 0.0 ) {
		this._add( x, y*eta );
	}
	return this;
});

/**
* Applies L2 regularization by shrinking the weight scale factor by `max(1 - eta*lambda, MIN_SHRINK_FACTOR)`.
*
* @private
* @name _regularize
* @memberof Model.prototype
* @param {number} eta - current learning rate
* @returns {Model} model instance
*/
setReadOnly( Model.prototype, '_regularize', function regularize( eta ) {
	var lambda = this._opts.lambda;
	if ( lambda <= 0.0 ) {
		return this;
	}
	this._scale( max( 1.0 - ( eta*lambda ), MIN_SHRINK_FACTOR ) );
	return this;
});

/**
* Multiplies the (logical) weight vector by a positive factor.
*
* ## Notes
*
* -   If the current scale factor is below `MIN_SCALE_FACTOR`, it is first folded into all stored weights, including the intercept, which is stored scaled as well. The scale factor is then reset to one.
*
* @private
* @name _scale
* @memberof Model.prototype
* @param {number} factor - scaling factor
* @throws {RangeError} scaling factor must be positive
* @returns {Model} model instance
*/
setReadOnly( Model.prototype, '_scale', function scale( factor ) {
	var w;
	var s;
	if ( factor <= 0.0 ) {
		throw new RangeError( format( 'invalid argument. Attempting to scale a weight vector by a nonpositive value. This is likely due to too large a value of eta * lambda. Value: `%f`.', factor ) );
	}
	s = this._scaleFactor;
	if ( s < MIN_SCALE_FACTOR ) {
		w = this._weights;
		dscal( w.length, s, w, 1 );
		this._scaleFactor = 1.0;
	}
	this._scaleFactor *= factor;
	return this;
});

/**
* Updates the model using the squared hinge loss.
*
* @private
* @name _squaredHingeLoss
* @memberof Model.prototype
* @param {VectorLike} x - one-dimensional ndarray of features
* @param {number} y - response value (`-1` or `+1`)
* @returns {Model} model instance
*/
setReadOnly( Model.prototype, '_squaredHingeLoss', function squaredHingeLoss( x, y ) {
	var eta;
	var d;

	eta = this[ this._learningRateMethod ]();
	this._regularize( eta );
	d = y * this._dot( x.data, x.strides[ 0 ], x.offset );
	if ( d < 1.0 ) {
		this._add( x, eta*( y-(d*y) ) );
	}
	return this;
});

/**
* Model coefficients (feature weights followed, if fitted, by the intercept), with the scale factor applied to every element.
*
* ## Notes
*
* -   The same ndarray instance is returned on every access; its contents are refreshed on each read.
*
* @private
* @name coefficients
* @memberof Model.prototype
* @type {ndarray}
*/
setReadOnlyAccessor( Model.prototype, 'coefficients', function getCoefficients() {
	var c = this._coefficients.data;
	var w = this._weights;
	dcopy( w.length, w, 1, c, 1 );

	// Scale all elements (including the intercept) so the reported values reproduce `_dot`:
	dscal( w.length, this._scaleFactor, c, 1 );
	return this._coefficients;
});

/**
* Number of model features.
*
* @private
* @name nfeatures
* @memberof Model.prototype
* @type {PositiveInteger}
*/
setReadOnlyAccessor( Model.prototype, 'nfeatures', function getNFeatures() {
	return this._N;
});

/**
* Computes predictions for each feature vector stored along the last dimension of an input ndarray.
*
* @private
* @name predict
* @memberof Model.prototype
* @param {ndarrayLike} X - input ndarray whose last dimension has size `N`
* @param {string} type - prediction type (`'label'`, `'probability'`, or `'linear'`)
* @returns {ndarray} output ndarray whose shape equals the input shape without its last dimension (zero-dimensional for a one-dimensional input)
*/
setReadOnly( Model.prototype, 'predict', function predict( X, type ) {
	var xbuf;
	var xsh;
	var xst;
	var osh;
	var ist;
	var ost;
	var ord;
	var len;
	var out;
	var idx;
	var nd;
	var v;
	var i;

	xbuf = X.data;
	xsh = X.shape;
	xst = X.strides;
	ord = X.order;

	// Number of "outer" dimensions (all dimensions except the trailing feature dimension):
	nd = xsh.length - 1;
	osh = [];
	ist = [];
	for ( i = 0; i < nd; i++ ) {
		osh.push( xsh[ i ] );
		ist.push( xst[ i ] );
	}
	if ( nd === 0 ) {
		// A single feature vector yields a zero-dimensional output:
		len = 1;
		ost = [ 0 ];
	} else {
		len = numel( osh );
		ost = shape2strides( osh, ord );
	}
	// The output buffer is a Float64Array holding labels, probabilities, or linear scores, so label the ndarray as 'float64':
	out = new ndarray( 'float64', new Float64Array( len ), osh, ost, 0, ord );

	for ( i = 0; i < len; i++ ) {
		if ( nd === 0 ) {
			idx = X.offset;
		} else {
			// Resolve the i-th outer element (in `ord` order, the same order `iset` uses below) to the buffer index of its first feature. Resolving over the outer dimensions only keeps this correct for column-major inputs:
			idx = vind2bind( osh, ist, X.offset, ord, i, 'throw' );
		}
		v = this._dot( xbuf, xst[ nd ], idx );
		if ( type === 'label' ) {
			v = ( v > 0.0 ) ? 1.0 : -1.0;
		} else if ( type === 'probability' ) {
			v = expit( v );
		}
		if ( nd === 0 ) {
			out.iset( v );
		} else {
			out.iset( i, v );
		}
	}
	return out;
});

/**
* Increments the iteration counter and updates the model with a single observation.
*
* @private
* @name update
* @memberof Model.prototype
* @param {VectorLike} x - one-dimensional ndarray of features
* @param {number} y - response value (`-1` or `+1`)
* @returns {Model} model instance
*/
setReadOnly( Model.prototype, 'update', function update( x, y ) {
	this._t += 1;
	return this[ this._lossMethod ]( x, y );
});


// FUNCTIONS //

/**
* Validates function options, writing accepted values into `opts`.
*
* @private
* @param {Object} opts - destination object
* @param {Options} options - function options
* @param {boolean} [options.intercept] - whether to fit an intercept
* @param {NonNegativeNumber} [options.lambda] - regularization parameter
* @param {ArrayLikeObject} [options.learningRate] - learning rate schedule name followed by optional parameters
* @param {string} [options.loss] - loss function name
* @returns {(Error|null)} validation error or `null`
*/
function validate( opts, options ) {
	var lr;
	if ( !isPlainObject( options ) ) {
		return new TypeError( format( 'invalid argument. Options argument must be an object. Value: `%s`.', options ) );
	}
	if ( hasOwnProp( options, 'intercept' ) ) {
		opts.intercept = options.intercept;
		if ( !isBoolean( opts.intercept ) ) {
			return new TypeError( format( 'invalid option. `%s` option must be a boolean. Option: `%s`.', 'intercept', opts.intercept ) );
		}
	}
	if ( hasOwnProp( options, 'lambda' ) ) {
		opts.lambda = options.lambda;
		if ( !isNonNegativeNumber( opts.lambda ) ) {
			return new TypeError( format( 'invalid option. `%s` option must be a nonnegative number. Option: `%s`.', 'lambda', opts.lambda ) );
		}
	}
	if ( hasOwnProp( options, 'learningRate' ) ) {
		if ( !isArrayLikeObject( options.learningRate ) ) {
			return new TypeError( format( 'invalid option. `%s` option must be an array-like object. Option: `%s`.', 'learningRate', options.learningRate ) );
		}
		lr = options.learningRate[ 0 ];
		if ( !contains( LEARNING_RATES, lr ) ) {
			return new TypeError( format( 'invalid option. First `%s` option must be one of the following: "%s". Option: `%s`.', 'learningRate', LEARNING_RATES.join( '", "' ), lr ) );
		}
		// Start from the selected schedule's defaults, so omitted parameters (e.g., `['constant']`) do not leave `undefined` entries:
		opts.learningRate = LEARNING_RATE_DEFAULTS[ lr ].slice();
		if ( options.learningRate.length > 1 && ( lr === 'constant' || lr === 'invscaling' ) ) {
			opts.learningRate[ 1 ] = options.learningRate[ 1 ];
			if ( !isPositiveNumber( opts.learningRate[ 1 ] ) ) {
				return new TypeError( format( 'invalid option. Second `%s` option must be a positive number. Option: `%s`.', 'learningRate', opts.learningRate[ 1 ] ) );
			}
		}
		if ( options.learningRate.length > 2 && lr === 'invscaling' ) {
			opts.learningRate[ 2 ] = options.learningRate[ 2 ];
			if ( !isNumber( opts.learningRate[ 2 ] ) ) {
				return new TypeError( format( 'invalid option. Third `%s` option must be a number. Option: `%s`.', 'learningRate', opts.learningRate[ 2 ] ) );
			}
		}
	}
	if ( hasOwnProp( options, 'loss' ) ) {
		opts.loss = options.loss;
		if ( !contains( LOSS_FUNCTIONS, opts.loss ) ) {
			return new TypeError( format( 'invalid option. `%s` option must be one of the following: "%s". Option: `%s`.', 'loss', LOSS_FUNCTIONS.join( '", "' ), opts.loss ) );
		}
	}
	return null;
}


// MAIN //

/**
* Returns an accumulator function which incrementally performs binary classification using stochastic gradient descent (SGD).
*
* @param {PositiveInteger} N - number of features
* @param {Options} [options] - options object
* @param {boolean} [options.intercept=true] - whether to fit an intercept
* @param {NonNegativeNumber} [options.lambda=1.0e-4] - regularization parameter
* @param {ArrayLikeObject} [options.learningRate=['basic']] - learning rate schedule (`'basic'`, `'constant'`, `'invscaling'`, `'pegasos'`) followed by optional parameters
* @param {string} [options.loss='log'] - loss function (`'hinge'`, `'log'`, `'modifiedHuber'`, `'perceptron'`, `'squaredHinge'`)
* @throws {TypeError} first argument must be a positive integer
* @throws {TypeError} options argument must be an object
* @throws {TypeError} must provide valid options
* @returns {Function} accumulator function; calling it with no arguments returns the current coefficients
*/
function incrBinaryClassification( N, options ) {
	var model;
	var opts;
	var err;

	if ( !isPositiveInteger( N ) ) {
		throw new TypeError( format( 'invalid argument. First argument must be a positive integer. Value: `%s`.', N ) );
	}
	opts = {
		'intercept': true,
		'lambda': 1.0e-4,
		'learningRate': LEARNING_RATE_DEFAULTS.basic.slice(),
		'loss': 'log'
	};
	if ( arguments.length > 1 ) {
		err = validate( opts, options );
		if ( err ) {
			throw err;
		}
	}
	model = new Model( N, opts );
	setReadOnly( accumulator, 'predict', predict );
	return accumulator;

	/**
	* Updates the model with a single observation, or returns the current coefficients when called without arguments.
	*
	* @private
	* @param {VectorLike} [x] - one-dimensional ndarray of length `N`
	* @param {integer} [y] - response value (`-1` or `+1`)
	* @throws {TypeError} first argument must be a one-dimensional ndarray of length `N`
	* @throws {TypeError} second argument must be either `+1` or `-1`
	* @returns {ndarray} model coefficients
	*/
	function accumulator( x, y ) {
		if ( arguments.length === 0 ) {
			return model.coefficients;
		}
		if ( !isVectorLike( x ) ) {
			throw new TypeError( format( 'invalid argument. First argument must be a one-dimensional ndarray. Value: `%s`.', x ) );
		}
		if ( y !== -1 && y !== 1 ) {
			throw new TypeError( format( 'invalid argument. Second argument must be either +1 or -1. Value: `%s`.', y ) );
		}
		if ( x.shape[ 0 ] !== model.nfeatures ) {
			throw new TypeError( format( 'invalid argument. First argument must be a one-dimensional ndarray of length %u. Actual length: `%u`.', model.nfeatures, x.shape[ 0 ] ) );
		}
		model.update( x, y );
		return model.coefficients;
	}

	/**
	* Computes predictions for the feature vectors in an ndarray.
	*
	* @private
	* @param {ndarrayLike} X - ndarray whose last dimension has size `N`
	* @param {string} [type='label'] - prediction type (`'label'`, `'probability'`, or `'linear'`)
	* @throws {TypeError} first argument must be an ndarray whose last dimension has size `N`
	* @throws {TypeError} second argument must be a supported prediction type
	* @throws {Error} probability predictions require the `log` or `modifiedHuber` loss
	* @returns {ndarray} predictions
	*/
	function predict( X, type ) {
		var sh;
		var t;

		if ( !isndarrayLike( X ) ) {
			throw new TypeError( format( 'invalid argument. First argument must be an ndarray. Value: `%s`.', X ) );
		}
		sh = X.shape;
		if ( sh[ sh.length-1 ] !== N ) {
			throw new TypeError( format( 'invalid argument. First argument must be an ndarray whose last dimension is of size %u. Actual size: `%u`.', N, sh[ sh.length-1 ] ) );
		}
		t = 'label';
		if ( arguments.length > 1 ) {
			if ( type === 'probability' ) {
				if ( opts.loss !== 'log' && opts.loss !== 'modifiedHuber' ) {
					throw new Error( format( 'invalid argument. Second argument is incompatible with model loss function. Probability predictions are only supported when the loss function is one of the following: "%s". Model loss function: `%s`.', [ 'log', 'modifiedHuber' ].join( '", "' ), opts.loss ) );
				}
			} else if ( type !== 'label' && type !== 'linear' ) {
				throw new TypeError( format( 'invalid argument. Second argument must be a string value equal to either "label", "probability", or "linear". Value: `%s`.', type ) );
			}
			t = type;
		}
		return model.predict( X, t );
	}
}


// EXPORTS //

module.exports = incrBinaryClassification;
/**
* @license Apache-2.0
*
* Copyright (c) 2018 The Stdlib Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//# sourceMappingURL=index.js.map