prosemirror-flat-list
Version:
Powerful list support for ProseMirror
102 lines (87 loc) • 2.84 kB
text/typescript
import { NodeRange } from 'prosemirror-model'
import type { Command } from 'prosemirror-state'
import { findWrapping } from 'prosemirror-transform'
import type { ListAttributes } from '../types'
import { getListType } from '../utils/get-list-type'
import { isListNode } from '../utils/is-list-node'
import { setNodeAttributes } from '../utils/set-node-attributes'
/**
* The list node attributes or a callback function to take the current
* selection block range and return list node attributes. If this callback
* function returns null, the command won't do anything.
*
* @public @group Commands
*/
export type WrapInListGetAttrs<T extends ListAttributes> =
| T
| ((range: NodeRange) => T | null)
/**
* Returns a command function that wraps the selection in a list with the given
* type and attributes.
*
* @public @group Commands
*/
export function createWrapInListCommand<
T extends ListAttributes = ListAttributes,
>(getAttrs: WrapInListGetAttrs<T>): Command {
const wrapInList: Command = (state, dispatch): boolean => {
const { $from, $to } = state.selection
let range = $from.blockRange($to)
if (!range) {
return false
}
if (
rangeAllowInlineContent(range) &&
isListNode(range.parent) &&
range.depth > 0 &&
range.startIndex === 0
) {
range = new NodeRange($from, $to, range.depth - 1)
}
const attrs: T | null =
typeof getAttrs === 'function' ? getAttrs(range) : getAttrs
if (!attrs) {
return false
}
const { parent, startIndex, endIndex, depth } = range
const tr = state.tr
const listType = getListType(state.schema)
for (let i = endIndex - 1; i >= startIndex; i--) {
const node = parent.child(i)
if (isListNode(node)) {
const oldAttrs: T = node.attrs as T
const newAttrs: T = { ...oldAttrs, ...attrs }
setNodeAttributes(tr, $from.posAtIndex(i, depth), oldAttrs, newAttrs)
} else {
const beforeNode = $from.posAtIndex(i, depth)
const afterNode = $from.posAtIndex(i + 1, depth)
let nodeStart = beforeNode + 1
let nodeEnd = afterNode - 1
if (nodeStart > nodeEnd) {
;[nodeStart, nodeEnd] = [nodeEnd, nodeStart]
}
const range = new NodeRange(
tr.doc.resolve(nodeStart),
tr.doc.resolve(nodeEnd),
depth,
)
const wrapping = findWrapping(range, listType, attrs)
if (wrapping) {
tr.wrap(range, wrapping)
}
}
}
dispatch?.(tr)
return true
}
return wrapInList
}
function rangeAllowInlineContent(range: NodeRange): boolean {
const { parent, startIndex, endIndex } = range
for (let i = startIndex; i < endIndex; i++) {
if (parent.child(i).inlineContent) {
return true
}
}
return false
}