diff --git a/server.go b/server.go index 8604a02..855a50d 100644 --- a/server.go +++ b/server.go @@ -117,7 +117,12 @@ func handleExecuteQuery(rw http.ResponseWriter, req *http.Request) { geo = &geoData{request.Geo.Latitude, request.Geo.Longitude} } - allEntries := getRecords(queryContext{geo, request.Profile, request.WalkingDist}) + allEntries, err := fetchRecords(db, queryContext{geo, request.Profile, request.WalkingDist}) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + features := fixFeatures(request.Features) modes := fixModes(request.Modes) @@ -241,13 +246,13 @@ func handleAddCategory(rw http.ResponseWriter, req *http.Request) { return } - affectedRows, err := result.RowsAffected() + rows, err := result.RowsAffected() if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) return } - response.Success = affectedRows > 0 + response.Success = rows > 0 response.Id = int(insertId) } @@ -315,6 +320,7 @@ func handleAccessReview(rw http.ResponseWriter, req *http.Request) { } if rowsAffected == 0 || len(request.Profile) == 0 { + http.Error(rw, err.Error(), http.StatusInternalServerError) return } diff --git a/util.go b/util.go index 1ddccf8..f38a6b4 100644 --- a/util.go +++ b/util.go @@ -23,10 +23,9 @@ package main import ( - "log" + "database/sql" "math" "strconv" - "sync" "github.com/kellydunn/golang-geo" ) @@ -69,7 +68,7 @@ func fixModes(modes map[string]string) map[string]modeType { return fixedModes } -func similarity(features1 map[string]float64, features2 map[string]float64) float64 { +func semanticSimilarity(features1 map[string]float64, features2 map[string]float64) float64 { var result float64 for key, value1 := range features1 { @@ -81,7 +80,7 @@ func similarity(features1 map[string]float64, features2 map[string]float64) floa return result } -func compare(features1 map[string]float64, features2 map[string]float64, modes map[string]modeType) float64 { +func semanticCompare(features1 map[string]float64, features2 map[string]float64, modes map[string]modeType) float64 { var result float64 for key, value1 := range features1 { @@ -92,8 +91,6 @@ func compare(features1 map[string]float64, features2 map[string]float64, modes m result += 1 - math.Abs(value1-value2) case modeTypeProd: result += value1 * value2 - default: - log.Fatal("unsupported compare mode") } } @@ -102,7 +99,7 @@ func compare(features1 map[string]float64, features2 map[string]float64, modes m func walkMatches(entries []record, features map[string]float64, modes map[string]modeType, minScore float64, callback func(record, float64)) { for _, entry := range entries { - if score := compare(features, entry.features, modes); score >= minScore { + if score := semanticCompare(features, entry.features, modes); score >= minScore { callback(entry, score) } } @@ -163,7 +160,7 @@ func project(entries []record, features map[string]float64, modes map[string]mod return projections } -func computeRecordsGeo(entries []record, context queryContext) { +func computeRecordGeo(entries []record, context queryContext) { distUserMin := math.MaxFloat64 distUserMax := 0.0 @@ -199,91 +196,74 @@ func computeRecordsGeo(entries []record, context queryContext) { } } -func computeRecordCompat(entry *record, context queryContext, wg *sync.WaitGroup) { - defer wg.Done() +func computeRecordCompat(db *sql.DB, entries []record, context queryContext) error { + for i := range entries { + entry := &entries[i] - historyRows, err := db.Query("SELECT id FROM history WHERE reviewId = (?)", entry.Id) - if err != nil { - log.Fatal(err) - } - defer historyRows.Close() - - var ( - groupSum float64 - groupCount int - ) - - for historyRows.Next() { - var historyId int - if err := historyRows.Scan(&historyId); err != nil { - log.Fatal(err) - } - - groupRows, err := db.Query("SELECT categoryId, categoryValue FROM historyGroups WHERE historyId = (?)", historyId) + historyRows, err := db.Query("SELECT id FROM history WHERE reviewId = (?)", entry.Id) if err != nil { - log.Fatal(err) + return err } - defer groupRows.Close() + defer historyRows.Close() - recordProfile := make(map[string]float64) - for groupRows.Next() { - var ( - categoryId int - categoryValue float64 - ) + var ( + groupSum float64 + groupCount int + ) - if err := groupRows.Scan(&categoryId, &categoryValue); err != nil { - log.Fatal(err) + for historyRows.Next() { + var historyId int + if err := historyRows.Scan(&historyId); err != nil { + return err } - recordProfile[strconv.Itoa(categoryId)] = categoryValue + groupRows, err := db.Query("SELECT categoryId, categoryValue FROM historyGroups WHERE historyId = (?)", historyId) + if err != nil { + return err + } + defer groupRows.Close() + + recordProfile := make(map[string]float64) + for groupRows.Next() { + var ( + categoryId int + categoryValue float64 + ) + + if err := groupRows.Scan(&categoryId, &categoryValue); err != nil { + return err + } + + recordProfile[strconv.Itoa(categoryId)] = categoryValue + } + if err := groupRows.Err(); err != nil { + return err + } + + groupSum += semanticSimilarity(recordProfile, context.profile) + groupCount++ } - if err := groupRows.Err(); err != nil { - log.Fatal(err) + if err := historyRows.Err(); err != nil { + return err } - groupSum += similarity(recordProfile, context.profile) - groupCount++ - } - if err := historyRows.Err(); err != nil { - log.Fatal(err) + if groupCount > 0 { + entry.Compatibility = groupSum / float64(groupCount) + } } - if groupCount > 0 { - entry.Compatibility = groupSum / float64(groupCount) - } + return nil } -func computeRecordsCompat(entries []record, context queryContext) { - count := len(entries) - limit := 32 - - for i := 0; i < count; i += limit { - batch := count - i - if batch > limit { - batch = limit - } - - var wg sync.WaitGroup - wg.Add(batch) - - for j := 0; j < batch; j++ { - go computeRecordCompat(&entries[i+j], context, &wg) - } - - wg.Wait() - } -} - -func getRecords(context queryContext) []record { - recordRows, err := db.Query("SELECT name, url, delicious, accommodating, affordable, atmospheric, latitude, longitude, closestStnDist, closestStnName, accessCount, id FROM reviews") +func fetchRecords(db *sql.DB, context queryContext) ([]record, error) { + rows, err := db.Query("SELECT name, url, delicious, accommodating, affordable, atmospheric, latitude, longitude, closestStnDist, closestStnName, accessCount, id FROM reviews") if err != nil { - log.Fatal(err) + return nil, err } - defer recordRows.Close() + defer rows.Close() var entries []record - for recordRows.Next() { + for rows.Next() { var ( name, url, closestStn string delicious, accommodating, affordable, atmospheric float64 @@ -291,7 +271,7 @@ func getRecords(context queryContext) []record { accessCount, id int ) - recordRows.Scan( + rows.Scan( &name, &url, &delicious, @@ -322,12 +302,14 @@ func getRecords(context queryContext) []record { entries = append(entries, entry) } - if err := recordRows.Err(); err != nil { - log.Fatal(err) + if err := rows.Err(); err != nil { + return nil, err } - computeRecordsCompat(entries, context) - computeRecordsGeo(entries, context) + computeRecordGeo(entries, context) + if err := computeRecordCompat(db, entries, context); err != nil { + return nil, err + } - return entries + return entries, nil }