📄 vectorSearch.ts  •  4179 bytes
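/**
 * CmdCode vector memory system - vector search (simplified)
 *
 * Note: because Bun's sqlite does not yet support virtual tables, vector search is simplified to:
 * - storing the embedding blob in the database
 * - searching with FTS5 keyword matching + in-memory vector computation
 */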
import { createHash } from 'node:crypto'
import { getDb } from './database'
import { packEmbedding } from './utils'
import { getEmbedding, clearCache } from './embedding'
import { t } from '../i18n.js'
import type { MessageInfo } from './sessionStore'

// In-memory vector index keyed by message id (intended for fast similarity search).
// NOTE(review): in the code visible here it is only written to and deleted from, never read — confirm whether a consumer exists elsewhere.
const memoryIndex = new Map<number, number[]>()
// Maximum number of entries kept in memoryIndex before the oldest one is evicted.
const MAX_MEMORY_INDEX = 1000

/**
 * Compute the embedding of a message and persist it.
 *
 * The embedding is packed with `packEmbedding` and upserted
 * (`INSERT OR REPLACE`) into `message_embeddings` together with a SHA-256 hex
 * digest of `content`. The unpacked embedding is also mirrored into the
 * module-level `memoryIndex`.
 *
 * @param msgId   Message id, written to `message_embeddings.msg_id`.
 * @param content Text to embed and hash.
 * @returns `true` once the row is written and the index updated; `false` if
 *          any step throws (the error is logged and not rethrown).
 */
export async function storeMessageEmbedding(msgId: number, content: string): Promise<boolean> {
  try {
    const embedding = await getEmbedding(content)
    const db = getDb()
    const vectorBlob = packEmbedding(embedding)
    const textHash = createHash('sha256').update(content).digest('hex')

    db.run(`
      INSERT OR REPLACE INTO message_embeddings (msg_id, embedding, text_hash, created_at)
      VALUES (?, ?, ?, datetime('now'))
    `, msgId, vectorBlob, textHash)

    // Mirror into the in-memory index.
    // Drop any previous entry for this id first: an overwrite then never evicts
    // an unrelated entry, and the refreshed entry becomes the newest one.
    memoryIndex.delete(msgId)
    if (memoryIndex.size >= MAX_MEMORY_INDEX) {
      // Map iterates in insertion order, so the first key is the oldest entry.
      const oldestKey = memoryIndex.keys().next().value
      if (oldestKey !== undefined) {
        memoryIndex.delete(oldestKey)
      }
    }
    memoryIndex.set(msgId, embedding)

    return true
  } catch (e) {
    console.error(t('error.save_vector'), e)
    return false
  }
}

/**
 * Cosine similarity of two vectors of equal length.
 *
 * @param a First vector.
 * @param b Second vector.
 * @returns `dot(a, b) / (|a| * |b| + 1e-10)`. Returns `0` when the vectors are
 *          empty, differ in length (not comparable), or the result is not a
 *          finite number (e.g. a component is NaN), so callers can always sort
 *          on the value.
 */
function cosineSimilarity(a: ArrayLike<number>, b: ArrayLike<number>): number {
  const dims = a.length
  // Vectors of different dimension are not comparable; indexing past the end
  // of the shorter one would otherwise yield NaN.
  if (dims === 0 || dims !== b.length) return 0

  let dot = 0, normA = 0, normB = 0
  for (let i = 0; i < dims; i++) {
    const x = a[i]
    const y = b[i]
    dot += x * y
    normA += x * x
    normB += y * y
  }

  // The epsilon keeps the division defined when either vector is all zeros.
  const similarity = dot / (Math.sqrt(normA) * Math.sqrt(normB) + 1e-10)
  return Number.isFinite(similarity) ? similarity : 0
}

/**
 * Decode a stored embedding BLOB into a plain number array.
 *
 * The bytes are read as 32-bit floats in platform byte order.
 * NOTE(review): this assumes `packEmbedding` writes raw float32 values — confirm against its implementation.
 *
 * @param blob Value of the `embedding` column (a typed-array view or an ArrayBuffer).
 * @returns The decoded vector, or `null` when the value is not binary, is
 *          empty, or its byte length is not a multiple of 4.
 */
function decodeEmbeddingBlob(blob: unknown): number[] | null {
  let bytes: Uint8Array
  if (ArrayBuffer.isView(blob)) {
    // Respect the view's window: its backing buffer may be larger than the blob.
    bytes = new Uint8Array(blob.buffer, blob.byteOffset, blob.byteLength)
  } else if (blob instanceof ArrayBuffer) {
    bytes = new Uint8Array(blob)
  } else {
    return null
  }

  const width = Float32Array.BYTES_PER_ELEMENT
  if (bytes.byteLength === 0 || bytes.byteLength % width !== 0) return null

  // A Float32Array view needs a 4-byte-aligned offset; copying re-bases the data at offset 0.
  const aligned = bytes.byteOffset % width === 0 ? bytes : new Uint8Array(bytes)
  return Array.from(new Float32Array(aligned.buffer, aligned.byteOffset, aligned.byteLength / width))
}

/**
 * Vector similarity search, computed in memory.
 *
 * Embeds `query`, loads every message that has a stored embedding (optionally
 * restricted to one session), scores each with cosine similarity and returns
 * the closest matches.
 *
 * @param query     Text to search for.
 * @param sessionId When given (non-empty), only messages of this session are considered.
 * @param limit     Maximum number of results; values below 0 are treated as 0.
 * @returns Messages with `distance = 1 - cosineSimilarity`, sorted ascending
 *          (smaller means more similar). Rows whose embedding cannot be decoded
 *          or whose dimension differs from the query's are skipped. Returns `[]`
 *          if anything throws (the error is logged and not rethrown).
 */
export async function searchVectors(query: string, sessionId?: string, limit = 20): Promise<(MessageInfo & { distance: number })[]> {
  try {
    const queryEmbedding = await getEmbedding(query)
    const db = getDb()

    // Load messages together with their embedding; the session filter is optional.
    const baseSql = `
      SELECT m.*, e.embedding FROM messages m
      LEFT JOIN message_embeddings e ON m.id = e.msg_id
      WHERE e.embedding IS NOT NULL
    `
    const sql = sessionId ? `${baseSql} AND m.session_id = ?` : baseSql
    const params: string[] = sessionId ? [sessionId] : []

    const rows = db.query(sql).all(...params) as (MessageInfo & { embedding: unknown })[]

    // Score every row against the query embedding.
    const results: (MessageInfo & { distance: number })[] = []
    for (const row of rows) {
      const embedding = decodeEmbeddingBlob(row.embedding)
      // Skip undecodable blobs and vectors of a different dimension instead of
      // failing the whole search or producing NaN distances.
      if (embedding === null || embedding.length !== queryEmbedding.length) continue

      const similarity = cosineSimilarity(queryEmbedding, embedding)
      if (!Number.isFinite(similarity)) continue

      results.push({
        id: row.id,
        session_id: row.session_id,
        role: row.role,
        content: row.content,
        created_at: row.created_at,
        distance: 1 - similarity // convert to a distance (smaller = more similar)
      })
    }

    // Sort by distance and return the first N; a negative limit would make
    // slice() count from the end, so clamp it to 0.
    results.sort((a, b) => a.distance - b.distance)
    return results.slice(0, Math.max(0, Math.trunc(limit)))
  } catch (e) {
    console.error(t('error.vector_search'), e)
    return []
  }
}

/**
 * Delete the stored vector of a message.
 *
 * Removes the row from `message_embeddings`, drops the entry from the
 * in-memory index and clears the embedding cache.
 *
 * @param msgId Id of the message whose vector is removed.
 * @returns `true` if at least one database row was deleted.
 */
export function deleteMessageVector(msgId: number): boolean {
  const deletion = getDb().run('DELETE FROM message_embeddings WHERE msg_id = ?', msgId)

  memoryIndex.delete(msgId)

  // P3 #2.7: clear the embedding cache after deleting a vector so that
  // re-writing identical content does not hit a stale cache entry.
  clearCache()

  const removedAny = deletion.changes > 0
  return removedAny
}

/**
 * Count the rows stored in `message_embeddings`.
 *
 * @returns The row count, or `0` when the query yields no row or a falsy count.
 */
export function getVectorCount(): number {
  const row = getDb()
    .query('SELECT COUNT(*) as cnt FROM message_embeddings')
    .get() as { cnt: number }

  if (!row) {
    return 0
  }
  return row.cnt || 0
}

/**
 * Check whether a message has a stored vector.
 *
 * @param msgId Id of the message to look up in `message_embeddings`.
 * @returns `true` when a matching row exists.
 */
export function hasVector(msgId: number): boolean {
  const match = getDb()
    .query('SELECT 1 FROM message_embeddings WHERE msg_id = ?')
    .get(msgId)
  return Boolean(match)
}