mpq loader interface cleanup

This commit is contained in:
Alex Yatskov 2018-12-15 20:48:20 -08:00
parent d136631eb7
commit 089780f619
2 changed files with 62 additions and 53 deletions

View File

@ -29,17 +29,17 @@ import (
type File interface { type File interface {
Read(data []byte) (int, error) Read(data []byte) (int, error)
GetSize() (int, error) GetSize() int
Close() error Close() error
} }
type Archive interface { type Archive interface {
OpenFile(path string) (File, error) OpenFile(path string) (File, error)
GetPaths() ([]string, error) GetPaths() []string
Close() error Close() error
} }
func New(path string) (Archive, error) { func NewFromFile(path string) (Archive, error) {
cs := C.CString(path) cs := C.CString(path)
defer C.free(unsafe.Pointer(cs)) defer C.free(unsafe.Pointer(cs))
@ -48,6 +48,11 @@ func New(path string) (Archive, error) {
return nil, fmt.Errorf("failed to open archive (%d)", getLastError()) return nil, fmt.Errorf("failed to open archive (%d)", getLastError())
} }
if err := a.buildPathMap(); err != nil {
a.Close()
return nil, err
}
return a, nil return a, nil
} }
@ -58,12 +63,7 @@ type file struct {
} }
func (f *file) Read(data []byte) (int, error) { func (f *file) Read(data []byte) (int, error) {
size, err := f.GetSize() bytesRemaining := f.size - f.offset
if err != nil {
return 0, err
}
bytesRemaining := size - f.offset
if bytesRemaining == 0 { if bytesRemaining == 0 {
return 0, io.EOF return 0, io.EOF
} }
@ -82,18 +82,8 @@ func (f *file) Read(data []byte) (int, error) {
return bytesRead, nil return bytesRead, nil
} }
func (f *file) GetSize() (int, error) { func (f *file) GetSize() int {
if f.size != math.MaxUint32 { return f.size
return f.size, nil
}
size := int(C.SFileGetFileSize(f.handle, nil))
if size == -1 {
return 0, fmt.Errorf("failed to get file size (%d)", getLastError())
}
f.size = size
return size, nil
} }
func (f *file) Close() error { func (f *file) Close() error {
@ -108,9 +98,19 @@ func (f *file) Close() error {
return nil return nil
} }
func (f *file) buildSize() error {
size := int(C.SFileGetFileSize(f.handle, nil))
if size == -1 {
return fmt.Errorf("failed to get file size (%d)", getLastError())
}
f.size = size
return nil
}
type archive struct { type archive struct {
handle unsafe.Pointer handle unsafe.Pointer
paths []string paths map[string]string
} }
func (a *archive) Close() error { func (a *archive) Close() error {
@ -125,7 +125,11 @@ func (a *archive) Close() error {
} }
func (a *archive) OpenFile(path string) (File, error) { func (a *archive) OpenFile(path string) (File, error) {
cs := C.CString(strings.Replace(path, string(os.PathSeparator), "\\", -1)) if pathInt, ok := a.paths[path]; ok {
path = pathInt
}
cs := C.CString(path)
defer C.free(unsafe.Pointer(cs)) defer C.free(unsafe.Pointer(cs))
file := &file{size: math.MaxUint32} file := &file{size: math.MaxUint32}
@ -133,34 +137,51 @@ func (a *archive) OpenFile(path string) (File, error) {
return nil, fmt.Errorf("failed to open file (%d)", getLastError()) return nil, fmt.Errorf("failed to open file (%d)", getLastError())
} }
if err := file.buildSize(); err != nil {
file.Close()
return nil, err
}
return file, nil return file, nil
} }
func (a *archive) GetPaths() ([]string, error) { func (a *archive) GetPaths() []string {
if len(a.paths) > 0 { var extPaths []string
return a.paths, nil for extPath := range a.paths {
extPaths = append(extPaths, extPath)
} }
return extPaths
}
func (a *archive) buildPathMap() error {
f, err := a.OpenFile("(listfile)") f, err := a.OpenFile("(listfile)")
if err != nil { if err != nil {
return nil, err return err
} }
defer f.Close() defer f.Close()
var buff bytes.Buffer var buff bytes.Buffer
if _, err := io.Copy(&buff, f); err != nil { if _, err := io.Copy(&buff, f); err != nil {
return nil, err return err
} }
for _, line := range strings.Split(string(buff.Bytes()), "\r\n") { a.paths = make(map[string]string)
line = strings.TrimSpace(line)
line = strings.Replace(line, "\\", string(os.PathSeparator), -1) lines := strings.Split(string(buff.Bytes()), "\r\n")
if len(line) > 0 { for _, line := range lines {
a.paths = append(a.paths, line) pathInt := strings.TrimSpace(line)
if len(pathInt) > 0 {
pathExt := santizePath(pathInt)
a.paths[pathExt] = pathInt
} }
} }
return a.paths, nil return nil
}
func santizePath(path string) string {
return strings.ToLower(strings.Replace(path, "\\", string(os.PathSeparator), -1))
} }
func getLastError() uint { func getLastError() uint {

View File

@ -6,25 +6,19 @@ import (
"io" "io"
"os" "os"
"path" "path"
"strings"
"github.com/FooSoft/lazarus/formats/mpq" "github.com/FooSoft/lazarus/formats/mpq"
"github.com/bmatcuk/doublestar" "github.com/bmatcuk/doublestar"
) )
func list(mpqPath, filter string) error { func list(mpqPath, filter string) error {
arch, err := mpq.New(mpqPath) arch, err := mpq.NewFromFile(mpqPath)
if err != nil { if err != nil {
return err return err
} }
defer arch.Close() defer arch.Close()
resPaths, err := arch.GetPaths() for _, resPath := range arch.GetPaths() {
if err != nil {
return err
}
for _, resPath := range resPaths {
match, err := doublestar.Match(filter, resPath) match, err := doublestar.Match(filter, resPath)
if err != nil { if err != nil {
return err return err
@ -38,19 +32,14 @@ func list(mpqPath, filter string) error {
return nil return nil
} }
func extract(mpqPath, filter, targetDir string, lowercase bool) error { func extract(mpqPath, filter, targetDir string) error {
arch, err := mpq.New(mpqPath) arch, err := mpq.NewFromFile(mpqPath)
if err != nil { if err != nil {
return err return err
} }
defer arch.Close() defer arch.Close()
resPaths, err := arch.GetPaths() for _, resPath := range arch.GetPaths() {
if err != nil {
return err
}
for _, resPath := range resPaths {
match, err := doublestar.Match(filter, resPath) match, err := doublestar.Match(filter, resPath)
if err != nil { if err != nil {
return err return err
@ -68,7 +57,7 @@ func extract(mpqPath, filter, targetDir string, lowercase bool) error {
} }
defer resFile.Close() defer resFile.Close()
sysPath := path.Join(targetDir, strings.ToLower(resPath)) sysPath := path.Join(targetDir, resPath)
if err := os.MkdirAll(path.Dir(sysPath), 0777); err != nil { if err := os.MkdirAll(path.Dir(sysPath), 0777); err != nil {
return err return err
} }
@ -93,7 +82,6 @@ func main() {
var ( var (
filter = flag.String("filter", "**", "wildcard file filter") filter = flag.String("filter", "**", "wildcard file filter")
targetDir = flag.String("target", ".", "target directory") targetDir = flag.String("target", ".", "target directory")
lowercase = flag.Bool("lowercase", true, "extract with lowercase paths")
) )
flag.Usage = func() { flag.Usage = func() {
@ -118,7 +106,7 @@ func main() {
} }
case "extract": case "extract":
for i := 1; i < flag.NArg(); i++ { for i := 1; i < flag.NArg(); i++ {
if err := extract(flag.Arg(i), *filter, *targetDir, *lowercase); err != nil { if err := extract(flag.Arg(i), *filter, *targetDir); err != nil {
fmt.Fprintln(os.Stderr, err) fmt.Fprintln(os.Stderr, err)
} }
} }