1
This commit is contained in:
Alex Yatskov 2024-04-24 20:48:40 -07:00
parent b20e48c972
commit a8e6c907e5

View File

@ -37,26 +37,26 @@ end
local BracePair = {} local BracePair = {}
function BracePair.new(open, close) function BracePair.new(open, close)
local brace_pair = {open = open, close = close} local pair = {open = open, close = close}
return setmetatable(brace_pair, {__index = BracePair}) return setmetatable(pair, {__index = BracePair})
end end
function BracePair.from_brace(brace) function BracePair.from_brace(brace)
local brace_pairs = { local all_pairs = {
{'(', ')'}, {'(', ')'},
{'[', ']'}, {'[', ']'},
{'{', '}'}, {'{', '}'},
{'<', '>'}, {'<', '>'},
} }
for _, brace_pair in ipairs(brace_pairs) do for _, pair in ipairs(all_pairs) do
if brace_pair[1] == brace or brace_pair[2] == brace then if pair[1] == brace or pair[2] == brace then
return BracePair.new(brace_pair[1], brace_pair[2]) return BracePair.new(pair[1], pair[2])
end end
end end
end end
function BracePair:escaped() function BracePair:get_escaped()
local escape_func = function(brace_raw) local escape_func = function(brace_raw)
if brace_raw == '[' or brace_raw == ']' then if brace_raw == '[' or brace_raw == ']' then
return '\\' .. brace_raw return '\\' .. brace_raw
@ -71,11 +71,7 @@ function BracePair:escaped()
) )
end end
function BracePair:find_closest(backward, cursor) function BracePair:find_closest(backward)
if not cursor then
cursor = Cursor.get_current()
end
-- See flags: https://neovim.io/doc/user/builtin.html#search() -- See flags: https://neovim.io/doc/user/builtin.html#search()
local flags = 'Wcn' local flags = 'Wcn'
if backward then if backward then
@ -83,10 +79,11 @@ function BracePair:find_closest(backward, cursor)
end end
local ignore_func = function() local ignore_func = function()
local cursor = Cursor.get_current()
return cursor:is_string() return cursor:is_string()
end end
local escaped_pair = self:escaped() local escaped_pair = self:get_escaped()
local position = vim.fn.searchpairpos( local position = vim.fn.searchpairpos(
escaped_pair.open, escaped_pair.open,
'', '',
@ -113,9 +110,9 @@ function BraceStack.new()
end end
function BraceStack:update(brace) function BraceStack:update(brace)
local brace_pair = BracePair.from_brace(brace) local pair = BracePair.from_brace(brace)
if brace_pair then if pair then
if brace == brace_pair.close and self:top() == brace_pair.open then if brace == pair.close and self:top() == pair.open then
self:pop() self:pop()
else else
self:push(brace) self:push(brace)
@ -146,41 +143,40 @@ end
local BraceRange = {} local BraceRange = {}
function BraceRange.new(start_cursor, stop_cursor, brace_pair, brace_params) function BraceRange.new(start, stop, pair)
local brace_range = { local range = {
start_cursor = start_cursor, start = start,
stop_cursor = stop_cursor, stop = stop,
brace_pair = brace_pair, pair = pair,
brace_params = brace_params,
} }
return setmetatable(brace_range, {__index = BraceRange}) return setmetatable(range, {__index = BraceRange})
end end
function BraceRange.find_closest(brace_pair) function BraceRange.find_closest(pair)
local stop_cursor = brace_pair:find_closest(false) local stop = pair:find_closest(false)
if stop_cursor then if stop then
local start_cursor = brace_pair:find_closest(true) local start = pair:find_closest(true)
if start_cursor then if start then
return BraceRange.new(start_cursor, stop_cursor, brace_pair, {}) return BraceRange.new(start, stop, pair, {})
end end
end end
end end
function BraceRange.find_closest_any() function BraceRange.find_closest_any()
local brace_range_compare = function(brace_range_1, brace_range_2) local range_compare = function(range_1, range_2)
local cursor = Cursor:get_current() local cursor = Cursor:get_current()
local row_diff1 = cursor.row - brace_range_1.start_cursor.row local row_diff1 = cursor.row - range_1.start.row
local row_diff2 = cursor.row - brace_range_2.start_cursor.row local row_diff2 = cursor.row - range_2.start.row
if row_diff1 < row_diff2 then if row_diff1 < row_diff2 then
return -1 return -1
elseif row_diff1 > row_diff2 then elseif row_diff1 > row_diff2 then
return 1 return 1
end end
local col_diff1 = cursor.col - brace_range_1.start_cursor.col local col_diff1 = cursor.col - range_1.start.col
local col_diff2 = cursor.col - brace_range_2.start_cursor.col local col_diff2 = cursor.col - range_2.start.col
if col_diff1 < col_diff2 then if col_diff1 < col_diff2 then
return -1 return -1
elseif col_diff1 > col_diff2 then elseif col_diff1 > col_diff2 then
@ -190,67 +186,67 @@ function BraceRange.find_closest_any()
return 0 return 0
end end
local brace_ranges = {} local ranges = {}
for _, brace in ipairs({'(', '[', '{', '<'}) do for _, brace in ipairs({'(', '[', '{', '<'}) do
local brace_pair = BracePair.from_brace(brace) local pair = BracePair.from_brace(brace)
local brace_range = BraceRange.find_closest(brace_pair) local range = BraceRange.find_closest(pair)
if brace_range then if range then
table.insert(brace_ranges, brace_range) table.insert(ranges, range)
end end
end end
if #brace_ranges > 0 then if #ranges > 0 then
vim.fn.sort(brace_ranges, brace_range_compare) vim.fn.sort(ranges, range_compare)
return brace_ranges[1] return ranges[1]
end end
end end
function BraceRange:is_wrapped() function BraceRange:is_wrapped()
return self.start_cursor.row < self.stop_cursor.row return self.start.row < self.stop.row
end end
-- --
-- Arg -- Param
-- --
local Arg = {} local Param = {}
function Arg.new(text, brace_pair) function Param.new(text, pair)
local arg = { local param = {
text = text, text = text,
brace_pair = brace_pair, pair = pair,
} }
return setmetatable(arg, {__index = Arg}) return setmetatable(param, {__index = Param})
end end
function Arg:append(char) function Param:append(char)
self.text = self.text .. char self.text = self.text .. char
end end
-- --
-- ArgList -- ParamList
-- --
local ArgList = {} local ParamList = {}
function ArgList.new() function ParamList.new()
local arg_list = { local params = {
arg = nil, current = nil,
args = {}, parsed = {},
} }
return setmetatable(arg_list, {__index = ArgList}) return setmetatable(params, {__index = ParamList})
end end
function ArgList:flush() function ParamList:flush()
if self.arg then if self.current then
table.insert(self.args, self.arg) table.insert(self.parsed, self.current)
self.arg = nil self.current = nil
end end
end end
function ArgList:update(char, brace_stack, brace_range, cursor) function ParamList:update(char, brace_stack, range, cursor)
if not cursor:is_string() then if not cursor:is_string() then
brace_stack:update(char) brace_stack:update(char)
if brace_stack:empty() and char == ',' then if brace_stack:empty() and char == ',' then
@ -259,31 +255,31 @@ function ArgList:update(char, brace_stack, brace_range, cursor)
end end
end end
if self.arg then if self.current then
self.arg:append(char) self.current:append(char)
else else
self.arg = Arg.new(char, brace_range) self.current = Param.new(char, range)
end end
end end
function ArgList:parse(brace_range) function ParamList:parse(range)
local brace_stack = BraceStack:new() local brace_stack = BraceStack:new()
for row = brace_range.start_cursor.row, brace_range.stop_cursor.row do for row = range.start.row, range.stop.row do
local line = vim.fn.getline(row) local line = vim.fn.getline(row)
local start_col = 1 local start_col = 1
if row == brace_range.start_cursor.row then if row == range.start.row then
start_col = brace_range.start_cursor.col + 1 start_col = range.start.col + 1
end end
local stop_col = #line local stop_col = #line
if row == brace_range.stop_cursor.row then if row == range.stop.row then
stop_col = brace_range.stop_cursor.col - 1 stop_col = range.stop.col - 1
end end
for col = start_col, stop_col do for col = start_col, stop_col do
self:update(line:sub(col, col), brace_stack, brace_range, Cursor.new(row, col)) self:update(line:sub(col, col), brace_stack, range, Cursor.new(row, col))
end end
end end
@ -302,28 +298,28 @@ function WrapContext.new(opts)
indent = '', indent = '',
prefix = '', prefix = '',
suffix = '', suffix = '',
brace_range = nil, range = nil,
arg_list = nil, params = nil,
} }
return setmetatable(wrap_context, {__index = WrapContext}) return setmetatable(wrap_context, {__index = WrapContext})
end end
function WrapContext:parse() function WrapContext:parse()
self.brace_range = BraceRange.find_closest_any() self.range = BraceRange.find_closest_any()
if not self.brace_range then if not self.range then
return false return false
end end
local first_line = vim.fn.getline(self.brace_range.start_cursor.row) local first_line = vim.fn.getline(self.range.start.row)
self.indent = first_line:match('^(%s*)') self.indent = first_line:match('^(%s*)')
self.prefix = first_line:sub(#self.indent + 1, self.brace_range.start_cursor.col) self.prefix = first_line:sub(#self.indent + 1, self.range.start.col)
local last_line = vim.fn.getline(self.brace_range.stop_cursor.row) local last_line = vim.fn.getline(self.range.stop.row)
self.suffix = last_line:sub(self.brace_range.stop_cursor.col) self.suffix = last_line:sub(self.range.stop.col)
self.arg_list = ArgList.new() self.params = ParamList.new()
self.arg_list:parse(self.brace_range) self.params:parse(self.range)
return true return true
end end
@ -331,41 +327,37 @@ end
function WrapContext:wrap() function WrapContext:wrap()
local line = self.indent .. self.prefix local line = self.indent .. self.prefix
for i, arg in ipairs(self.arg_list.args) do for i, param in ipairs(self.params.parsed) do
line = line .. arg.text line = line .. param.text
if i < #self.arg_list.args then if i < #self.params.parsed then
line = line .. ', ' line = line .. ', '
end end
end end
line = line .. self.suffix line = line .. self.suffix
vim.fn.setline(self.brace_range.start_cursor.row, line) vim.fn.setline(self.range.start.row, line)
vim.fn.execute(string.format( vim.fn.execute(string.format('%d,%dd_', self.range.start.row + 1, self.range.stop.row))
'%d,%dd_',
self.brace_range.start_cursor.row + 1,
self.brace_range.stop_cursor.row
))
end end
function WrapContext:unwrap() function WrapContext:unwrap()
vim.fn.setline( vim.fn.setline(
self.brace_range.start_cursor.row, self.range.start.row,
self.indent .. self.prefix self.indent .. self.prefix
) )
local cursor = nil local cursor = nil
local row = self.brace_range.start_cursor.row local row = self.range.start.row
for i, arg in ipairs(self.arg_list.args) do for i, param in ipairs(self.params.parsed) do
local on_last_arg = i == #self.arg_list.args local on_last_param = i == #self.params.parsed
local line = self.indent .. arg.text local line = self.indent .. param.text
if self.opts.tail_comma or not on_last_arg then if self.opts.tail_comma or not on_last_param then
line = line .. ',' line = line .. ','
end end
if on_last_arg and not self.opts.wrap_closing_brace then if on_last_param and not self.opts.wrap_closing_brace then
line = line .. self.suffix line = line .. self.suffix
end end
@ -391,7 +383,7 @@ function WrapContext:unwrap()
end end
function WrapContext:toggle() function WrapContext:toggle()
if self.brace_range:is_wrapped() then if self.range:is_wrapped() then
self:wrap() self:wrap()
else else
self:unwrap() self:unwrap()