diff --git a/lua/argonaut.lua b/lua/argonaut.lua index a7881f5..e6f4ac9 100644 --- a/lua/argonaut.lua +++ b/lua/argonaut.lua @@ -22,7 +22,7 @@ local function get_cursor_pos() return {row = row, col = col} end -local function is_literal(pos) +local function is_string_literal(pos) if not pos then pos = get_cursor_pos() end @@ -33,13 +33,36 @@ local function is_literal(pos) return syn_attr:find('String$') end -local function find_range(brace_pair) - local row1, col1 = unpack(vim.fn.searchpairpos(brace_pair[1], '', brace_pair[2], 'Wnb', is_literal)) +local function find_brace_alt(brace) + local brace_pairs = {{'(', ')'}, {'[', ']'}, {'{', '}'}, {'<', '>'}} + + for _, brace_pair in ipairs(brace_pairs) do + if brace_pair[1] == brace then + return brace_pair[2], true + elseif brace_pair[2] == brace then + return brace_pair[1], false + end + end +end + +local function escape_brace(brace) + if brace == '[' or brace == ']' then + return '\\' .. brace + else + return brace + end +end + +local function find_brace_range(brace) + local brace_alt, _ = find_brace_alt(brace) + assert(brace_alt) + + local row1, col1 = unpack(vim.fn.searchpairpos(escape_brace(brace), '', escape_brace(brace_alt), 'Wnb', is_string_literal)) if row1 > 0 and col1 > 0 then - local row2, col2 = unpack(vim.fn.searchpairpos(brace_pair[1], '', brace_pair[2], 'Wcn', is_literal)) + local row2, col2 = unpack(vim.fn.searchpairpos(escape_brace(brace), '', escape_brace(brace_alt), 'Wcn', is_string_literal)) if row2 > 0 and col2 > 0 then return { - brace_pair = brace_pair, + brace = brace, row1 = row1, col1 = col1, row2 = row2, @@ -49,34 +72,33 @@ local function find_range(brace_pair) end end -local function find_all_ranges(brace_pairs) - local ranges = {} - for _, brace_pair in ipairs(brace_pairs) do - local range = find_range(brace_pair) - if range then - table.insert(ranges, find_range(brace_pair)) +local function find_all_brace_ranges(braces) + local brace_ranges = {} + for _, brace in ipairs(braces) do + local brace_range = find_brace_range(brace) + if brace_range then + table.insert(brace_ranges, brace_range) end end - if #ranges > 0 then - return ranges + if #brace_ranges > 0 then + return brace_ranges end end -local function find_closest_range(brace_pairs) - local function compare_ranges(range1, range2) - local pos = get_cursor_pos() - - local row_diff1 = pos.row - range1.row1 - local row_diff2 = pos.row - range2.row1 +local function find_closest_brace_range(braces) + local cursor_pos = get_cursor_pos() + local compare_brace_ranges = function(brace_range_1, brace_range_2) + local row_diff1 = cursor_pos.row - brace_range_1.row1 + local row_diff2 = cursor_pos.row - brace_range_2.row1 if row_diff1 < row_diff2 then return -1 elseif row_diff1 > row_diff2 then return 1 end - local col_diff1 = pos.col - range1.col1 - local col_diff2 = pos.col - range2.col1 + local col_diff1 = cursor_pos.col - brace_range_1.col1 + local col_diff2 = cursor_pos.col - brace_range_2.col1 if col_diff1 < col_diff2 then return -1 elseif col_diff1 > col_diff2 then @@ -86,97 +108,84 @@ local function find_closest_range(brace_pairs) return 0 end - local ranges = find_all_ranges(brace_pairs) - if ranges then - return vim.fn.sort(ranges, compare_ranges)[1] + local brace_ranges = find_all_brace_ranges(braces) + if brace_ranges then + return vim.fn.sort(brace_ranges, compare_brace_ranges)[1] end end -local function parse_range(range) - local elements = {} - for row = range.row1, range.row2 do - local line = vim.fn.getline(row) - - local col1 = 0 - if row == range.row1 then - col1 = range.col1 + 1 - end - - local col2 = #line - if row == range.row2 then - col2 = range.col2 - 1 - end - - for i = col1, col2 do - table.insert(elements, { - char = line:sub(i, i), ---@diagnostic disable-line: undefined-field - literal = is_literal({row = row, col = i}), - }) - end - end - - local params = {} - local param = '' - local flush_param = function() - if #param > 0 then - table.insert(params, param) - param = '' +local function parse_brace_range(brace_range) + local brace_range_params = {} + local brace_range_param = '' + local flush_brace_range_param = function() + if #brace_range_param > 0 then + brace_range_param, _ = brace_range_param:gsub('^%s*([%S%s]-)%s*$', '%1') + table.insert(brace_range_params, brace_range_param) + brace_range_param = '' end end local brace_stack = {} - local update_brace_stack = function(char) + local update_brace_stack = function(c) local brace_stack_size = #brace_stack - - local brace_pairs_forward = { - ['('] = ')', - ['['] = ']', - ['{'] = '}', - ['<'] = '>' - } - local brace_pairs_backward = { - [')'] = '(', - [']'] = '[', - ['}'] = '{', - ['>'] = '<' - } - - if brace_stack_size > 0 and brace_stack[brace_stack_size] == brace_pairs_backward[char] then + local brace_alt, brace_open = find_brace_alt(c) + if brace_stack_size > 0 and brace_alt == brace_stack[brace_stack_size] and not brace_open then table.remove(brace_stack, brace_stack_size) - elseif brace_pairs_forward[char] then - table.insert(brace_stack, char) + elseif brace_alt then + table.insert(brace_stack, c) end end - if #elements > 0 then - for _, element in ipairs(elements) do - local concat = true - if not element.literal then - update_brace_stack(element.char) - if #brace_stack == 0 and element.char == ',' then - flush_param() - concat = false + local brace_range_elements = {} + for row = brace_range.row1, brace_range.row2 do + local line = vim.fn.getline(row) + + local col1 = 0 + if row == brace_range.row1 then + col1 = brace_range.col1 + 1 + end + + local col2 = #line + if row == brace_range.row2 then + col2 = brace_range.col2 - 1 + end + + for i = col1, col2 do + table.insert(brace_range_elements, { + char = line:sub(i, i), ---@diagnostic disable-line: undefined-field + literal = is_string_literal({row = row, col = i}), + }) + end + end + + if #brace_range_elements > 0 then + for _, brace_range_element in ipairs(brace_range_elements) do + local append = true + if not brace_range_element.literal then + update_brace_stack(brace_range_element.char) + if #brace_stack == 0 and brace_range_element.char == ',' then + flush_brace_range_param() + append = false end end - if concat then - param = param .. element.char + if append then + brace_range_param = brace_range_param .. brace_range_element.char end end - flush_param() + flush_brace_range_param() end - for _, p in ipairs(params) do + for _, p in ipairs(brace_range_params) do print(p) end end local function reflow() - local brace_pairs = {{'(', ')'}, {'\\[', '\\]'}, {'{', '}'}, {'<', '>'}} - local range = find_closest_range(brace_pairs) - if range then - parse_range(range) + local brace_range = find_closest_brace_range({'(', '[', '{', '<'}) + if brace_range then + parse_brace_range(brace_range) end end