1
This commit is contained in:
Alex Yatskov 2015-08-24 15:42:16 +09:00
parent d253afeaf7
commit 83d709398c
2 changed files with 72 additions and 84 deletions

View File

@ -117,7 +117,12 @@ func handleExecuteQuery(rw http.ResponseWriter, req *http.Request) {
geo = &geoData{request.Geo.Latitude, request.Geo.Longitude} geo = &geoData{request.Geo.Latitude, request.Geo.Longitude}
} }
allEntries := getRecords(queryContext{geo, request.Profile, request.WalkingDist}) allEntries, err := fetchRecords(db, queryContext{geo, request.Profile, request.WalkingDist})
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
features := fixFeatures(request.Features) features := fixFeatures(request.Features)
modes := fixModes(request.Modes) modes := fixModes(request.Modes)
@ -241,13 +246,13 @@ func handleAddCategory(rw http.ResponseWriter, req *http.Request) {
return return
} }
affectedRows, err := result.RowsAffected() rows, err := result.RowsAffected()
if err != nil { if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError) http.Error(rw, err.Error(), http.StatusInternalServerError)
return return
} }
response.Success = affectedRows > 0 response.Success = rows > 0
response.Id = int(insertId) response.Id = int(insertId)
} }
@ -315,6 +320,7 @@ func handleAccessReview(rw http.ResponseWriter, req *http.Request) {
} }
if rowsAffected == 0 || len(request.Profile) == 0 { if rowsAffected == 0 || len(request.Profile) == 0 {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return return
} }

144
util.go
View File

@ -23,10 +23,9 @@
package main package main
import ( import (
"log" "database/sql"
"math" "math"
"strconv" "strconv"
"sync"
"github.com/kellydunn/golang-geo" "github.com/kellydunn/golang-geo"
) )
@ -69,7 +68,7 @@ func fixModes(modes map[string]string) map[string]modeType {
return fixedModes return fixedModes
} }
func similarity(features1 map[string]float64, features2 map[string]float64) float64 { func semanticSimilarity(features1 map[string]float64, features2 map[string]float64) float64 {
var result float64 var result float64
for key, value1 := range features1 { for key, value1 := range features1 {
@ -81,7 +80,7 @@ func similarity(features1 map[string]float64, features2 map[string]float64) floa
return result return result
} }
func compare(features1 map[string]float64, features2 map[string]float64, modes map[string]modeType) float64 { func semanticCompare(features1 map[string]float64, features2 map[string]float64, modes map[string]modeType) float64 {
var result float64 var result float64
for key, value1 := range features1 { for key, value1 := range features1 {
@ -92,8 +91,6 @@ func compare(features1 map[string]float64, features2 map[string]float64, modes m
result += 1 - math.Abs(value1-value2) result += 1 - math.Abs(value1-value2)
case modeTypeProd: case modeTypeProd:
result += value1 * value2 result += value1 * value2
default:
log.Fatal("unsupported compare mode")
} }
} }
@ -102,7 +99,7 @@ func compare(features1 map[string]float64, features2 map[string]float64, modes m
func walkMatches(entries []record, features map[string]float64, modes map[string]modeType, minScore float64, callback func(record, float64)) { func walkMatches(entries []record, features map[string]float64, modes map[string]modeType, minScore float64, callback func(record, float64)) {
for _, entry := range entries { for _, entry := range entries {
if score := compare(features, entry.features, modes); score >= minScore { if score := semanticCompare(features, entry.features, modes); score >= minScore {
callback(entry, score) callback(entry, score)
} }
} }
@ -163,7 +160,7 @@ func project(entries []record, features map[string]float64, modes map[string]mod
return projections return projections
} }
func computeRecordsGeo(entries []record, context queryContext) { func computeRecordGeo(entries []record, context queryContext) {
distUserMin := math.MaxFloat64 distUserMin := math.MaxFloat64
distUserMax := 0.0 distUserMax := 0.0
@ -199,91 +196,74 @@ func computeRecordsGeo(entries []record, context queryContext) {
} }
} }
func computeRecordCompat(entry *record, context queryContext, wg *sync.WaitGroup) { func computeRecordCompat(db *sql.DB, entries []record, context queryContext) error {
defer wg.Done() for i := range entries {
entry := &entries[i]
historyRows, err := db.Query("SELECT id FROM history WHERE reviewId = (?)", entry.Id) historyRows, err := db.Query("SELECT id FROM history WHERE reviewId = (?)", entry.Id)
if err != nil {
log.Fatal(err)
}
defer historyRows.Close()
var (
groupSum float64
groupCount int
)
for historyRows.Next() {
var historyId int
if err := historyRows.Scan(&historyId); err != nil {
log.Fatal(err)
}
groupRows, err := db.Query("SELECT categoryId, categoryValue FROM historyGroups WHERE historyId = (?)", historyId)
if err != nil { if err != nil {
log.Fatal(err) return err
} }
defer groupRows.Close() defer historyRows.Close()
recordProfile := make(map[string]float64) var (
for groupRows.Next() { groupSum float64
var ( groupCount int
categoryId int )
categoryValue float64
)
if err := groupRows.Scan(&categoryId, &categoryValue); err != nil { for historyRows.Next() {
log.Fatal(err) var historyId int
if err := historyRows.Scan(&historyId); err != nil {
return err
} }
recordProfile[strconv.Itoa(categoryId)] = categoryValue groupRows, err := db.Query("SELECT categoryId, categoryValue FROM historyGroups WHERE historyId = (?)", historyId)
if err != nil {
return err
}
defer groupRows.Close()
recordProfile := make(map[string]float64)
for groupRows.Next() {
var (
categoryId int
categoryValue float64
)
if err := groupRows.Scan(&categoryId, &categoryValue); err != nil {
return err
}
recordProfile[strconv.Itoa(categoryId)] = categoryValue
}
if err := groupRows.Err(); err != nil {
return err
}
groupSum += semanticSimilarity(recordProfile, context.profile)
groupCount++
} }
if err := groupRows.Err(); err != nil { if err := historyRows.Err(); err != nil {
log.Fatal(err) return err
} }
groupSum += similarity(recordProfile, context.profile) if groupCount > 0 {
groupCount++ entry.Compatibility = groupSum / float64(groupCount)
} }
if err := historyRows.Err(); err != nil {
log.Fatal(err)
} }
if groupCount > 0 { return nil
entry.Compatibility = groupSum / float64(groupCount)
}
} }
func computeRecordsCompat(entries []record, context queryContext) { func fetchRecords(db *sql.DB, context queryContext) ([]record, error) {
count := len(entries) rows, err := db.Query("SELECT name, url, delicious, accommodating, affordable, atmospheric, latitude, longitude, closestStnDist, closestStnName, accessCount, id FROM reviews")
limit := 32
for i := 0; i < count; i += limit {
batch := count - i
if batch > limit {
batch = limit
}
var wg sync.WaitGroup
wg.Add(batch)
for j := 0; j < batch; j++ {
go computeRecordCompat(&entries[i+j], context, &wg)
}
wg.Wait()
}
}
func getRecords(context queryContext) []record {
recordRows, err := db.Query("SELECT name, url, delicious, accommodating, affordable, atmospheric, latitude, longitude, closestStnDist, closestStnName, accessCount, id FROM reviews")
if err != nil { if err != nil {
log.Fatal(err) return nil, err
} }
defer recordRows.Close() defer rows.Close()
var entries []record var entries []record
for recordRows.Next() { for rows.Next() {
var ( var (
name, url, closestStn string name, url, closestStn string
delicious, accommodating, affordable, atmospheric float64 delicious, accommodating, affordable, atmospheric float64
@ -291,7 +271,7 @@ func getRecords(context queryContext) []record {
accessCount, id int accessCount, id int
) )
recordRows.Scan( rows.Scan(
&name, &name,
&url, &url,
&delicious, &delicious,
@ -322,12 +302,14 @@ func getRecords(context queryContext) []record {
entries = append(entries, entry) entries = append(entries, entry)
} }
if err := recordRows.Err(); err != nil { if err := rows.Err(); err != nil {
log.Fatal(err) return nil, err
} }
computeRecordsCompat(entries, context) computeRecordGeo(entries, context)
computeRecordsGeo(entries, context) if err := computeRecordCompat(db, entries, context); err != nil {
return nil, err
}
return entries return entries, nil
} }