diff --git a/util.go b/util.go index ac3e593..86f1cf8 100644 --- a/util.go +++ b/util.go @@ -26,6 +26,7 @@ import ( "log" "math" "strconv" + "sync" "github.com/kellydunn/golang-geo" ) @@ -159,7 +160,7 @@ func project(entries records, features featureMap, modes modeMap, featureName st return projection } -func computeRecordGeo(entries records, context queryContext) { +func computeRecordsGeo(entries records, context queryContext) { distUserMin := math.MaxFloat64 distUserMax := 0.0 @@ -195,56 +196,75 @@ func computeRecordGeo(entries records, context queryContext) { } } -func computeRecordCompat(entries records, context queryContext) { - for index := range entries { - entry := &entries[index] +func computeRecordCompat(entry *record, context queryContext, wg *sync.WaitGroup) { + historyRows, err := db.Query("SELECT id FROM history WHERE reviewId = (?)", entry.id) + if err != nil { + log.Fatal(err) + } + defer historyRows.Close() - historyRows, err := db.Query("SELECT id FROM history WHERE reviewId = (?)", entry.id) + var groupSum float64 + var groupCount int + + for historyRows.Next() { + var historyId int + if err := historyRows.Scan(&historyId); err != nil { + log.Fatal(err) + } + + groupRows, err := db.Query("SELECT categoryId, categoryValue FROM historyGroups WHERE historyId = (?)", historyId) if err != nil { log.Fatal(err) } - defer historyRows.Close() + defer groupRows.Close() - var groupSum float64 - var groupCount int + recordProfile := make(featureMap) + for groupRows.Next() { + var categoryId int + var categoryValue float64 - for historyRows.Next() { - var historyId int - if err := historyRows.Scan(&historyId); err != nil { + if err := groupRows.Scan(&categoryId, &categoryValue); err != nil { log.Fatal(err) } - groupRows, err := db.Query("SELECT categoryId, categoryValue FROM historyGroups WHERE historyId = (?)", historyId) - if err != nil { - log.Fatal(err) - } - defer groupRows.Close() - - recordProfile := make(featureMap) - for groupRows.Next() { - var categoryId int - var categoryValue float64 - - if err := groupRows.Scan(&categoryId, &categoryValue); err != nil { - log.Fatal(err) - } - - recordProfile[strconv.Itoa(categoryId)] = categoryValue - } - if err := groupRows.Err(); err != nil { - log.Fatal(err) - } - - groupSum += distance(recordProfile, context.profile) - groupCount++ + recordProfile[strconv.Itoa(categoryId)] = categoryValue } - if err := historyRows.Err(); err != nil { + if err := groupRows.Err(); err != nil { log.Fatal(err) } - if groupCount > 0 { - entry.compatibility = groupSum / float64(groupCount) + groupSum += distance(recordProfile, context.profile) + groupCount++ + } + if err := historyRows.Err(); err != nil { + log.Fatal(err) + } + + if groupCount > 0 { + entry.compatibility = groupSum / float64(groupCount) + } + + wg.Done() +} + +func computeRecordsCompat(entries records, context queryContext) { + count := len(entries) + limit := 32 + + for i := 0; i < count; i += limit { + batch := count - i + if batch > limit { + batch = limit } + + var wg sync.WaitGroup + wg.Add(batch) + + for j := 0; j < batch; j++ { + go computeRecordCompat(&entries[i+j], context, &wg) + } + + wg.Wait() } } @@ -296,8 +316,8 @@ func getRecords(context queryContext) records { log.Fatal(err) } - computeRecordCompat(entries, context) - computeRecordGeo(entries, context) + computeRecordsCompat(entries, context) + computeRecordsGeo(entries, context) return entries }