diff --git a/lua/hflip.lua b/lua/hflip.lua index dd77e87..5b9e8d1 100644 --- a/lua/hflip.lua +++ b/lua/hflip.lua @@ -6,48 +6,35 @@ local function file_exists(path) return vim.loop.fs_stat(path) ~= nil end -local function locate_flip_path(name, exts) - for key, _ in pairs(exts) do - local path = name .. key - if file_exists(path) then - return path - end +local function hflip() + local exts = {'.h', '.hpp', '.c', '.cpp', '.inl'} + for i = 1, #exts do + table.insert(exts, exts[i]:upper()) + end - local path_upper = name .. key:upper() - if file_exists(path_upper) then - return path_upper + local path = vim.api.nvim_buf_get_name(vim.api.nvim_get_current_buf()) + local name, ext = split_ext(path) + if not ext then + return + end + + local ext_index = nil + for i, e in ipairs(exts) do + if e == ext:lower() then + ext_index = i + break end end -end + if not ext_index then + return + end -local function hflip() - local header_exts = { - ['.h'] = true, - ['.hpp'] = true, - ['.lua'] = true, - } - - local source_exts = { - ['.c'] = true, - ['.cpp'] = true - } - - local index = vim.api.nvim_get_current_buf() - local path = vim.api.nvim_buf_get_name(index) - local name, ext = split_ext(path) - - if ext ~= nil then - local ext_lower = ext:lower() - - local flip_path - if header_exts[ext_lower] then - flip_path = locate_flip_path(name, source_exts) - elseif source_exts[ext_lower] then - flip_path = locate_flip_path(name, header_exts) - end - - if flip_path then - vim.cmd(string.format('e %s', flip_path)) + for i = 1, #exts do + local j = (i + ext_index - 1) % #exts + 1 + local next_path = name .. exts[j] + if file_exists(next_path) then + vim.cmd(string.format('e %s', next_path)) + break end end end