ndarray-sort
Version:
Sorts ndarrays in place
728 lines (663 loc) • 19.9 kB
JavaScript
"use strict"
var pool = require("typedarray-pool")
var INSERTION_SORT_THRESHOLD = 32
function getMallocFree(dtype) {
switch(dtype) {
case "uint8":
return [pool.mallocUint8, pool.freeUint8]
case "uint16":
return [pool.mallocUint16, pool.freeUint16]
case "uint32":
return [pool.mallocUint32, pool.freeUint32]
case "int8":
return [pool.mallocInt8, pool.freeInt8]
case "int16":
return [pool.mallocInt16, pool.freeInt16]
case "int32":
return [pool.mallocInt32, pool.freeInt32]
case "float32":
return [pool.mallocFloat, pool.freeFloat]
case "float64":
return [pool.mallocDouble, pool.freeDouble]
default:
return null
}
}
function shapeArgs(dimension) {
var args = []
for(var i=0; i<dimension; ++i) {
args.push("s"+i)
}
for(var i=0; i<dimension; ++i) {
args.push("n"+i)
}
for(var i=1; i<dimension; ++i) {
args.push("d"+i)
}
for(var i=1; i<dimension; ++i) {
args.push("e"+i)
}
for(var i=1; i<dimension; ++i) {
args.push("f"+i)
}
return args
}
function createInsertionSort(order, dtype) {
var code = ["'use strict'"]
var funcName = ["ndarrayInsertionSort", order.join("d"), dtype].join("")
var funcArgs = ["left", "right", "data", "offset" ].concat(shapeArgs(order.length))
var allocator = getMallocFree(dtype)
var vars = [ "i,j,cptr,ptr=left*s0+offset" ]
if(order.length > 1) {
var scratch_shape = []
for(var i=1; i<order.length; ++i) {
vars.push("i"+i)
scratch_shape.push("n"+i)
}
if(allocator) {
vars.push("scratch=malloc(" + scratch_shape.join("*") + ")")
} else {
vars.push("scratch=new Array("+scratch_shape.join("*") + ")")
}
vars.push("dptr","sptr","a","b")
} else {
vars.push("scratch")
}
function dataRead(ptr) {
if(dtype === "generic") {
return ["data.get(", ptr, ")"].join("")
}
return ["data[",ptr,"]"].join("")
}
function dataWrite(ptr, v) {
if(dtype === "generic") {
return ["data.set(", ptr, ",", v, ")"].join("")
}
return ["data[",ptr,"]=",v].join("")
}
//Create function header
code.push(
["function ", funcName, "(", funcArgs.join(","), "){var ", vars.join(",")].join(""),
"for(i=left+1;i<=right;++i){",
"j=i;ptr+=s0",
"cptr=ptr")
if(order.length > 1) {
//Copy data into scratch
code.push("dptr=0;sptr=ptr")
for(var i=order.length-1; i>=0; --i) {
var j = order[i]
if(j === 0) {
continue
}
code.push(["for(i",j,"=0;i",j,"<n",j,";++i",j,"){"].join(""))
}
code.push("scratch[dptr++]=",dataRead("sptr"))
for(var i=0; i<order.length; ++i) {
var j = order[i]
if(j === 0) {
continue
}
code.push("sptr+=d"+j,"}")
}
//Compare items in outer loop
code.push("__g:while(j-->left){",
"dptr=0",
"sptr=cptr-s0")
for(var i=1; i<order.length; ++i) {
if(i === 1) {
code.push("__l:")
}
code.push(["for(i",i,"=0;i",i,"<n",i,";++i",i,"){"].join(""))
}
code.push(["a=", dataRead("sptr"),"\nb=scratch[dptr]\nif(a<b){break __g}\nif(a>b){break __l}"].join(""))
for(var i=order.length-1; i>=1; --i) {
code.push(
"sptr+=e"+i,
"dptr+=f"+i,
"}")
}
//Copy data back
code.push("dptr=cptr;sptr=cptr-s0")
for(var i=order.length-1; i>=0; --i) {
var j = order[i]
if(j === 0) {
continue
}
code.push(["for(i",j,"=0;i",j,"<n",j,";++i",j,"){"].join(""))
}
code.push(dataWrite("dptr", dataRead("sptr")))
for(var i=0; i<order.length; ++i) {
var j = order[i]
if(j === 0) {
continue
}
code.push(["dptr+=d",j,";sptr+=d",j].join(""),"}")
}
//Close while loop
code.push("cptr-=s0\n}")
//Copy scratch into cptr
code.push("dptr=cptr;sptr=0")
for(var i=order.length-1; i>=0; --i) {
var j = order[i]
if(j === 0) {
continue
}
code.push(["for(i",j,"=0;i",j,"<n",j,";++i",j,"){"].join(""))
}
code.push(dataWrite("dptr", "scratch[sptr++]"))
for(var i=0; i<order.length; ++i) {
var j = order[i]
if(j === 0) {
continue
}
code.push("dptr+=d"+j,"}")
}
} else {
code.push("scratch=" + dataRead("ptr"),
"while((j-->left)&&("+dataRead("cptr-s0")+">scratch)){",
dataWrite("cptr", dataRead("cptr-s0")),
"cptr-=s0",
"}",
dataWrite("cptr", "scratch"))
}
//Close outer loop body
code.push("}")
if(order.length > 1 && allocator) {
code.push("free(scratch)")
}
code.push("} return " + funcName)
//Compile and link function
if(allocator) {
var result = new Function("malloc", "free", code.join("\n"))
return result(allocator[0], allocator[1])
} else {
var result = new Function(code.join("\n"))
return result()
}
}
function createQuickSort(order, dtype, insertionSort) {
var code = [ "'use strict'" ]
var funcName = ["ndarrayQuickSort", order.join("d"), dtype].join("")
var funcArgs = ["left", "right", "data", "offset" ].concat(shapeArgs(order.length))
var allocator = getMallocFree(dtype)
var labelCounter=0
code.push(["function ", funcName, "(", funcArgs.join(","), "){"].join(""))
var vars = [
"sixth=((right-left+1)/6)|0",
"index1=left+sixth",
"index5=right-sixth",
"index3=(left+right)>>1",
"index2=index3-sixth",
"index4=index3+sixth",
"el1=index1",
"el2=index2",
"el3=index3",
"el4=index4",
"el5=index5",
"less=left+1",
"great=right-1",
"pivots_are_equal=true",
"tmp",
"tmp0",
"x",
"y",
"z",
"k",
"ptr0",
"ptr1",
"ptr2",
"comp_pivot1=0",
"comp_pivot2=0",
"comp=0"
]
if(order.length > 1) {
var ele_size = []
for(var i=1; i<order.length; ++i) {
ele_size.push("n"+i)
vars.push("i"+i)
}
for(var i=0; i<8; ++i) {
vars.push("b_ptr"+i)
}
vars.push(
"ptr3",
"ptr4",
"ptr5",
"ptr6",
"ptr7",
"pivot_ptr",
"ptr_shift",
"elementSize="+ele_size.join("*"))
if(allocator) {
vars.push("pivot1=malloc(elementSize)",
"pivot2=malloc(elementSize)")
} else {
vars.push("pivot1=new Array(elementSize),pivot2=new Array(elementSize)")
}
} else {
vars.push("pivot1", "pivot2")
}
//Initialize local variables
code.push("var " + vars.join(","))
function toPointer(v) {
return ["(offset+",v,"*s0)"].join("")
}
function dataRead(ptr) {
if(dtype === "generic") {
return ["data.get(", ptr, ")"].join("")
}
return ["data[",ptr,"]"].join("")
}
function dataWrite(ptr, v) {
if(dtype === "generic") {
return ["data.set(", ptr, ",", v, ")"].join("")
}
return ["data[",ptr,"]=",v].join("")
}
function cacheLoop(ptrs, usePivot, body) {
if(ptrs.length === 1) {
code.push("ptr0="+toPointer(ptrs[0]))
} else {
for(var i=0; i<ptrs.length; ++i) {
code.push(["b_ptr",i,"=s0*",ptrs[i]].join(""))
}
}
if(usePivot) {
code.push("pivot_ptr=0")
}
code.push("ptr_shift=offset")
for(var i=order.length-1; i>=0; --i) {
var j = order[i]
if(j === 0) {
continue
}
code.push(["for(i",j,"=0;i",j,"<n",j,";++i",j,"){"].join(""))
}
if(ptrs.length > 1) {
for(var i=0; i<ptrs.length; ++i) {
code.push(["ptr",i,"=b_ptr",i,"+ptr_shift"].join(""))
}
}
code.push(body)
if(usePivot) {
code.push("++pivot_ptr")
}
for(var i=0; i<order.length; ++i) {
var j = order[i]
if(j === 0) {
continue
}
if(ptrs.length>1) {
code.push("ptr_shift+=d"+j)
} else {
code.push("ptr0+=d"+j)
}
code.push("}")
}
}
function lexicoLoop(label, ptrs, usePivot, body) {
if(ptrs.length === 1) {
code.push("ptr0="+toPointer(ptrs[0]))
} else {
for(var i=0; i<ptrs.length; ++i) {
code.push(["b_ptr",i,"=s0*",ptrs[i]].join(""))
}
code.push("ptr_shift=offset")
}
if(usePivot) {
code.push("pivot_ptr=0")
}
if(label) {
code.push(label+":")
}
for(var i=1; i<order.length; ++i) {
code.push(["for(i",i,"=0;i",i,"<n",i,";++i",i,"){"].join(""))
}
if(ptrs.length > 1) {
for(var i=0; i<ptrs.length; ++i) {
code.push(["ptr",i,"=b_ptr",i,"+ptr_shift"].join(""))
}
}
code.push(body)
for(var i=order.length-1; i>=1; --i) {
if(usePivot) {
code.push("pivot_ptr+=f"+i)
}
if(ptrs.length > 1) {
code.push("ptr_shift+=e"+i)
} else {
code.push("ptr0+=e"+i)
}
code.push("}")
}
}
function cleanUp() {
if(order.length > 1 && allocator) {
code.push("free(pivot1)", "free(pivot2)")
}
}
function compareSwap(a_id, b_id) {
var a = "el"+a_id
var b = "el"+b_id
if(order.length > 1) {
var lbl = "__l" + (++labelCounter)
lexicoLoop(lbl, [a, b], false, [
"comp=",dataRead("ptr0"),"-",dataRead("ptr1"),"\n",
"if(comp>0){tmp0=", a, ";",a,"=",b,";", b,"=tmp0;break ", lbl,"}\n",
"if(comp<0){break ", lbl, "}"
].join(""))
} else {
code.push(["if(", dataRead(toPointer(a)), ">", dataRead(toPointer(b)), "){tmp0=", a, ";",a,"=",b,";", b,"=tmp0}"].join(""))
}
}
compareSwap(1, 2)
compareSwap(4, 5)
compareSwap(1, 3)
compareSwap(2, 3)
compareSwap(1, 4)
compareSwap(3, 4)
compareSwap(2, 5)
compareSwap(2, 3)
compareSwap(4, 5)
if(order.length > 1) {
cacheLoop(["el1", "el2", "el3", "el4", "el5", "index1", "index3", "index5"], true, [
"pivot1[pivot_ptr]=",dataRead("ptr1"),"\n",
"pivot2[pivot_ptr]=",dataRead("ptr3"),"\n",
"pivots_are_equal=pivots_are_equal&&(pivot1[pivot_ptr]===pivot2[pivot_ptr])\n",
"x=",dataRead("ptr0"),"\n",
"y=",dataRead("ptr2"),"\n",
"z=",dataRead("ptr4"),"\n",
dataWrite("ptr5", "x"),"\n",
dataWrite("ptr6", "y"),"\n",
dataWrite("ptr7", "z")
].join(""))
} else {
code.push([
"pivot1=", dataRead(toPointer("el2")), "\n",
"pivot2=", dataRead(toPointer("el4")), "\n",
"pivots_are_equal=pivot1===pivot2\n",
"x=", dataRead(toPointer("el1")), "\n",
"y=", dataRead(toPointer("el3")), "\n",
"z=", dataRead(toPointer("el5")), "\n",
dataWrite(toPointer("index1"), "x"), "\n",
dataWrite(toPointer("index3"), "y"), "\n",
dataWrite(toPointer("index5"), "z")
].join(""))
}
function moveElement(dst, src) {
if(order.length > 1) {
cacheLoop([dst, src], false,
dataWrite("ptr0", dataRead("ptr1"))
)
} else {
code.push(dataWrite(toPointer(dst), dataRead(toPointer(src))))
}
}
moveElement("index2", "left")
moveElement("index4", "right")
function comparePivot(result, ptr, n) {
if(order.length > 1) {
var lbl = "__l" + (++labelCounter)
lexicoLoop(lbl, [ptr], true, [
result,"=",dataRead("ptr0"),"-pivot",n,"[pivot_ptr]\n",
"if(",result,"!==0){break ", lbl, "}"
].join(""))
} else {
code.push([result,"=", dataRead(toPointer(ptr)), "-pivot", n].join(""))
}
}
function swapElements(a, b) {
if(order.length > 1) {
cacheLoop([a,b],false,[
"tmp=",dataRead("ptr0"),"\n",
dataWrite("ptr0", dataRead("ptr1")),"\n",
dataWrite("ptr1", "tmp")
].join(""))
} else {
code.push([
"ptr0=",toPointer(a),"\n",
"ptr1=",toPointer(b),"\n",
"tmp=",dataRead("ptr0"),"\n",
dataWrite("ptr0", dataRead("ptr1")),"\n",
dataWrite("ptr1", "tmp")
].join(""))
}
}
function tripleSwap(k, less, great) {
if(order.length > 1) {
cacheLoop([k,less,great], false, [
"tmp=",dataRead("ptr0"),"\n",
dataWrite("ptr0", dataRead("ptr1")),"\n",
dataWrite("ptr1", dataRead("ptr2")),"\n",
dataWrite("ptr2", "tmp")
].join(""))
code.push("++"+less, "--"+great)
} else {
code.push([
"ptr0=",toPointer(k),"\n",
"ptr1=",toPointer(less),"\n",
"ptr2=",toPointer(great),"\n",
"++",less,"\n",
"--",great,"\n",
"tmp=", dataRead("ptr0"), "\n",
dataWrite("ptr0", dataRead("ptr1")), "\n",
dataWrite("ptr1", dataRead("ptr2")), "\n",
dataWrite("ptr2", "tmp")
].join(""))
}
}
function swapAndDecrement(k, great) {
swapElements(k, great)
code.push("--"+great)
}
code.push("if(pivots_are_equal){")
//Pivots are equal case
code.push("for(k=less;k<=great;++k){")
comparePivot("comp", "k", 1)
code.push("if(comp===0){continue}")
code.push("if(comp<0){")
code.push("if(k!==less){")
swapElements("k", "less")
code.push("}")
code.push("++less")
code.push("}else{")
code.push("while(true){")
comparePivot("comp", "great", 1)
code.push("if(comp>0){")
code.push("great--")
code.push("}else if(comp<0){")
tripleSwap("k", "less", "great")
code.push("break")
code.push("}else{")
swapAndDecrement("k", "great")
code.push("break")
code.push("}")
code.push("}")
code.push("}")
code.push("}")
code.push("}else{")
//Pivots not equal case
code.push("for(k=less;k<=great;++k){")
comparePivot("comp_pivot1", "k", 1)
code.push("if(comp_pivot1<0){")
code.push("if(k!==less){")
swapElements("k", "less")
code.push("}")
code.push("++less")
code.push("}else{")
comparePivot("comp_pivot2", "k", 2)
code.push("if(comp_pivot2>0){")
code.push("while(true){")
comparePivot("comp", "great", 2)
code.push("if(comp>0){")
code.push("if(--great<k){break}")
code.push("continue")
code.push("}else{")
comparePivot("comp", "great", 1)
code.push("if(comp<0){")
tripleSwap("k", "less", "great")
code.push("}else{")
swapAndDecrement("k", "great")
code.push("}")
code.push("break")
code.push("}")
code.push("}")
code.push("}")
code.push("}")
code.push("}")
code.push("}")
//Move pivots to correct place
function storePivot(mem_dest, pivot_dest, pivot) {
if(order.length>1) {
cacheLoop([mem_dest, pivot_dest], true, [
dataWrite("ptr0", dataRead("ptr1")), "\n",
dataWrite("ptr1", ["pivot",pivot,"[pivot_ptr]"].join(""))
].join(""))
} else {
code.push(
dataWrite(toPointer(mem_dest), dataRead(toPointer(pivot_dest))),
dataWrite(toPointer(pivot_dest), "pivot"+pivot))
}
}
storePivot("left", "(less-1)", 1)
storePivot("right", "(great+1)", 2)
//Recursive sort call
function doSort(left, right) {
code.push([
"if((",right,"-",left,")<=",INSERTION_SORT_THRESHOLD,"){\n",
"insertionSort(", left, ",", right, ",data,offset,", shapeArgs(order.length).join(","), ")\n",
"}else{\n",
funcName, "(", left, ",", right, ",data,offset,", shapeArgs(order.length).join(","), ")\n",
"}"
].join(""))
}
doSort("left", "(less-2)")
doSort("(great+2)", "right")
//If pivots are equal, then early out
code.push("if(pivots_are_equal){")
cleanUp()
code.push("return")
code.push("}")
function walkPointer(ptr, pivot, body) {
if(order.length > 1) {
code.push(["__l",++labelCounter,":while(true){"].join(""))
cacheLoop([ptr], true, [
"if(", dataRead("ptr0"), "!==pivot", pivot, "[pivot_ptr]){break __l", labelCounter, "}"
].join(""))
code.push(body, "}")
} else {
code.push(["while(", dataRead(toPointer(ptr)), "===pivot", pivot, "){", body, "}"].join(""))
}
}
//Check bounds
code.push("if(less<index1&&great>index5){")
walkPointer("less", 1, "++less")
walkPointer("great", 2, "--great")
code.push("for(k=less;k<=great;++k){")
comparePivot("comp_pivot1", "k", 1)
code.push("if(comp_pivot1===0){")
code.push("if(k!==less){")
swapElements("k", "less")
code.push("}")
code.push("++less")
code.push("}else{")
comparePivot("comp_pivot2", "k", 2)
code.push("if(comp_pivot2===0){")
code.push("while(true){")
comparePivot("comp", "great", 2)
code.push("if(comp===0){")
code.push("if(--great<k){break}")
code.push("continue")
code.push("}else{")
comparePivot("comp", "great", 1)
code.push("if(comp<0){")
tripleSwap("k", "less", "great")
code.push("}else{")
swapAndDecrement("k", "great")
code.push("}")
code.push("break")
code.push("}")
code.push("}")
code.push("}")
code.push("}")
code.push("}")
code.push("}")
//Clean up and do a final sorting pass
cleanUp()
doSort("less", "great")
//Close off main loop
code.push("}return " + funcName)
//Compile and link
if(order.length > 1 && allocator) {
var compiled = new Function("insertionSort", "malloc", "free", code.join("\n"))
return compiled(insertionSort, allocator[0], allocator[1])
}
var compiled = new Function("insertionSort", code.join("\n"))
return compiled(insertionSort)
}
function compileSort(order, dtype) {
var code = ["'use strict'"]
var funcName = ["ndarraySortWrapper", order.join("d"), dtype].join("")
var funcArgs = [ "array" ]
code.push(["function ", funcName, "(", funcArgs.join(","), "){"].join(""))
//Unpack local variables from array
var vars = ["data=array.data,offset=array.offset|0,shape=array.shape,stride=array.stride"]
for(var i=0; i<order.length; ++i) {
vars.push(["s",i,"=stride[",i,"]|0,n",i,"=shape[",i,"]|0"].join(""))
}
var scratch_stride = new Array(order.length)
var nprod = []
for(var i=0; i<order.length; ++i) {
var k = order[i]
if(k === 0) {
continue
}
if(nprod.length === 0) {
scratch_stride[k] = "1"
} else {
scratch_stride[k] = nprod.join("*")
}
nprod.push("n"+k)
}
var p = -1, q = -1
for(var i=0; i<order.length; ++i) {
var j = order[i]
if(j !== 0) {
if(p > 0) {
vars.push(["d",j,"=s",j,"-d",p,"*n",p].join(""))
} else {
vars.push(["d",j,"=s",j].join(""))
}
p = j
}
var k = order.length-1-i
if(k !== 0) {
if(q > 0) {
vars.push(["e",k,"=s",k,"-e",q,"*n",q,
",f",k,"=",scratch_stride[k],"-f",q,"*n",q].join(""))
} else {
vars.push(["e",k,"=s",k,",f",k,"=",scratch_stride[k]].join(""))
}
q = k
}
}
//Declare local variables
code.push("var " + vars.join(","))
//Create arguments for subroutine
var sortArgs = ["0", "n0-1", "data", "offset"].concat(shapeArgs(order.length))
//Call main sorting routine
code.push([
"if(n0<=",INSERTION_SORT_THRESHOLD,"){",
"insertionSort(", sortArgs.join(","), ")}else{",
"quickSort(", sortArgs.join(","),
")}"
].join(""))
//Return
code.push("}return " + funcName)
//Link everything together
var result = new Function("insertionSort", "quickSort", code.join("\n"))
var insertionSort = createInsertionSort(order, dtype)
var quickSort = createQuickSort(order, dtype, insertionSort)
return result(insertionSort, quickSort)
}
module.exports = compileSort