"""
Lemontropia Suite - Icon Matcher Module

Icon similarity matching using multiple algorithms.
Supports perceptual hashing, template matching, and feature-based matching.
"""
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import logging
|
|
import json
|
|
from pathlib import Path
|
|
from dataclasses import dataclass, asdict
|
|
from typing import Optional, List, Dict, Tuple, Any
|
|
import sqlite3
|
|
import pickle
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class MatchResult:
    """Result of a single icon match attempt.

    Attributes:
        item_name: Human-readable name of the matched item.
        confidence: Match confidence in [0, 1].
        match_method: Algorithm that produced the match ('hash', 'feature',
            or 'template').
        item_id: Optional stable identifier of the item.
        category: Optional item category.
        metadata: Arbitrary extra data; normalized to an empty dict when
            omitted or passed as None.
    """
    item_name: str
    confidence: float
    match_method: str
    item_id: Optional[str] = None
    category: Optional[str] = None
    # Annotated Optional because None is the sentinel default (the original
    # annotation claimed Dict but the default was None); normalized below.
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Give every instance its own dict — avoids the shared-mutable-default
        # pitfall while still letting callers pass metadata=None explicitly.
        if self.metadata is None:
            self.metadata = {}
|
|
|
|
|
|
class PerceptualHash:
    """Perceptual hash implementations for icon matching.

    All hashes are returned as '0'/'1' bit strings so they can be stored as
    TEXT and compared with hamming_distance() / similarity(). Hashes are
    only comparable when computed with the same method and hash_size.
    """

    @staticmethod
    def average_hash(image: np.ndarray, hash_size: int = 16) -> str:
        """Compute average hash (aHash): hash_size**2 bits.

        Each bit is 1 when the downsampled pixel is brighter than the
        image's mean intensity.
        """
        # Convert to grayscale
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image

        # Downsample; INTER_AREA averages source pixels, which is what aHash wants
        resized = cv2.resize(gray, (hash_size, hash_size), interpolation=cv2.INTER_AREA)

        # Threshold every cell against the global mean
        avg = resized.mean()
        hash_bits = (resized > avg).flatten()
        return ''.join(['1' if b else '0' for b in hash_bits])

    @staticmethod
    def difference_hash(image: np.ndarray, hash_size: int = 16) -> str:
        """Compute difference hash (dHash): hash_size**2 bits.

        Each bit is 1 when a pixel is brighter than its left neighbour,
        making the hash robust to global brightness changes.
        """
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image

        # hash_size+1 columns so each row yields exactly hash_size differences
        resized = cv2.resize(gray, (hash_size + 1, hash_size), interpolation=cv2.INTER_AREA)

        diff = resized[:, 1:] > resized[:, :-1]
        return ''.join(['1' if b else '0' for b in diff.flatten()])

    @staticmethod
    def wavelet_hash(image: np.ndarray, hash_size: int = 16) -> str:
        """Compute wavelet hash (wHash) using a single-level Haar transform.

        The image is resized to the next power of two >= hash_size; dwt2
        halves each dimension, so the hash has (size/2)**2 bits.
        """
        # Resize target: next power of two >= hash_size
        size = 2 ** (hash_size - 1).bit_length()

        try:
            import pywt
        except ImportError:
            logger.debug("PyWavelets not available, falling back to average hash")
            # BUGFIX: the fallback previously produced hash_size**2 bits while
            # the wavelet path produces (size/2)**2 bits; mismatched lengths
            # make hamming_distance() raise. Match the wavelet bit length.
            return PerceptualHash.average_hash(image, size // 2)

        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image

        resized = cv2.resize(gray, (size, size), interpolation=cv2.INTER_AREA)

        # Single-level 2D Haar decomposition; keep only the approximation band
        cA, (cH, cV, cD) = pywt.dwt2(resized, 'haar')

        avg = cA.mean()
        hash_bits = (cA > avg).flatten()
        return ''.join(['1' if b else '0' for b in hash_bits])

    @staticmethod
    def hamming_distance(hash1: str, hash2: str) -> int:
        """Count bit positions where the two hashes differ.

        Raises:
            ValueError: if the hashes have different lengths.
        """
        if len(hash1) != len(hash2):
            raise ValueError("Hashes must be same length")
        return sum(c1 != c2 for c1, c2 in zip(hash1, hash2))

    @staticmethod
    def similarity(hash1: str, hash2: str) -> float:
        """Return hash similarity in [0, 1]; 1.0 means identical."""
        distance = PerceptualHash.hamming_distance(hash1, hash2)
        max_distance = len(hash1)
        if max_distance == 0:
            # Two empty hashes are trivially identical; avoid division by zero
            return 1.0
        return 1.0 - (distance / max_distance)
|
|
|
|
|
|
class FeatureMatcher:
    """Feature-based icon matching using ORB/SIFT."""

    def __init__(self):
        # ORB detector capped at 500 keypoints; brute-force Hamming matcher
        # with cross-checking so matches are mutually best.
        self.orb = cv2.ORB_create(nfeatures=500)
        self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)

    def extract_features(self, image: np.ndarray) -> Tuple[List, np.ndarray]:
        """Extract ORB keypoints and descriptors from *image*."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
        return self.orb.detectAndCompute(gray, None)

    def match_features(self, desc1: np.ndarray, desc2: np.ndarray,
                       threshold: float = 0.7) -> float:
        """
        Match features between two descriptor sets.

        Returns confidence score (0-1).

        NOTE(review): *threshold* is currently unused — good matches are
        selected with a fixed distance cutoff of 50; confirm intent.
        """
        if desc1 is None or desc2 is None:
            return 0.0

        try:
            ordered = sorted(self.matcher.match(desc1, desc2),
                             key=lambda m: m.distance)

            # Too few correspondences to trust at all
            if len(ordered) < 4:
                return 0.0

            # Keep only matches below the fixed Hamming-distance cutoff
            good = [m for m in ordered if m.distance < 50]
            if not good:
                return 0.0

            # Normalize against 20 good matches, clamped to 1.0
            return min(len(good) / 20, 1.0)

        except Exception as e:
            logger.debug(f"Feature matching failed: {e}")
            return 0.0
|
|
|
|
|
|
class TemplateMatcher:
    """Template matching for icons."""

    @staticmethod
    def match(template: np.ndarray, image: np.ndarray,
              methods: Optional[List[int]] = None) -> float:
        """
        Match template to image using multiple methods.

        The image is first resized to the template's size, so each result
        map collapses to a single score.

        Returns best confidence score (0-1 for the normalized methods).
        """
        if methods is None:
            methods = [
                cv2.TM_CCOEFF_NORMED,
                cv2.TM_CCORR_NORMED,
                cv2.TM_SQDIFF_NORMED
            ]

        # Ensure same size so matchTemplate yields exactly one value
        h, w = template.shape[:2]
        image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)

        best_score = 0.0

        for method in methods:
            try:
                result = cv2.matchTemplate(image, template, method)
                min_val, max_val, _, _ = cv2.minMaxLoc(result)

                # BUGFIX: SQDIFF is a distance — lower is better — so the best
                # match is the MINIMUM, inverted into a similarity. The old
                # code read max_val, which was only equivalent because the
                # equal-size inputs produce a 1x1 result map.
                if method == cv2.TM_SQDIFF_NORMED:
                    score = 1.0 - min_val
                else:
                    score = max_val

                best_score = max(best_score, score)
            except Exception as e:
                logger.debug(f"Template matching failed: {e}")
                continue

        return best_score
|
|
|
|
|
|
class IconDatabase:
    """SQLite-backed store for icon hashes and ORB feature descriptors."""

    def __init__(self, db_path: Optional[Path] = None):
        """Open (creating if necessary) the database at *db_path*.

        Defaults to ~/.lemontropia/icon_database.db; parent directories
        are created automatically.
        """
        self.db_path = db_path or Path.home() / ".lemontropia" / "icon_database.db"
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_database()

    def _init_database(self):
        """Create the icons table and its indexes if they do not exist."""
        conn = sqlite3.connect(str(self.db_path))
        try:
            cursor = conn.cursor()

            cursor.execute('''
                CREATE TABLE IF NOT EXISTS icons (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    item_name TEXT NOT NULL,
                    item_id TEXT,
                    category TEXT,
                    avg_hash TEXT,
                    diff_hash TEXT,
                    wavelet_hash TEXT,
                    features BLOB,
                    metadata TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            cursor.execute('''
                CREATE INDEX IF NOT EXISTS idx_avg_hash ON icons(avg_hash)
            ''')

            cursor.execute('''
                CREATE INDEX IF NOT EXISTS idx_item_name ON icons(item_name)
            ''')

            conn.commit()
        finally:
            # Close even if a statement raises (the old code leaked here)
            conn.close()

    def add_icon(self, item_name: str, image: np.ndarray,
                 item_id: Optional[str] = None,
                 category: Optional[str] = None,
                 metadata: Optional[Dict] = None) -> bool:
        """Compute hashes/features for *image* and store them under *item_name*.

        Returns True on success, False on any failure (logged, not raised).
        """
        try:
            # Compute all three perceptual hashes up front
            avg_hash = PerceptualHash.average_hash(image)
            diff_hash = PerceptualHash.difference_hash(image)
            wavelet_hash = PerceptualHash.wavelet_hash(image)

            # ORB descriptors are pickled for storage. NOTE(review): pickle is
            # only acceptable because this DB is written and read locally by
            # this module — never load blobs from an untrusted database file.
            feature_matcher = FeatureMatcher()
            _, features = feature_matcher.extract_features(image)
            features_blob = pickle.dumps(features) if features is not None else None

            conn = sqlite3.connect(str(self.db_path))
            try:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO icons
                    (item_name, item_id, category, avg_hash, diff_hash, wavelet_hash, features, metadata)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    item_name, item_id, category,
                    avg_hash, diff_hash, wavelet_hash,
                    features_blob,
                    json.dumps(metadata) if metadata else None
                ))
                conn.commit()
            finally:
                conn.close()

            logger.debug(f"Added icon to database: {item_name}")
            return True

        except Exception as e:
            logger.error(f"Failed to add icon: {e}")
            return False

    def find_by_hash(self, avg_hash: str, max_distance: int = 10) -> List[Tuple[str, float, Dict]]:
        """Find icons whose stored average hash is within *max_distance* bits.

        Returns (item_name, similarity, meta) tuples sorted best-first.
        Rows with a NULL or differently-sized stored hash are skipped
        instead of raising ValueError (robustness fix).
        """
        conn = sqlite3.connect(str(self.db_path))
        try:
            cursor = conn.cursor()
            cursor.execute('SELECT item_name, avg_hash, diff_hash, item_id, category, metadata FROM icons')
            results = []

            for row in cursor.fetchall():
                item_name, db_avg_hash, db_diff_hash, item_id, category, metadata_json = row

                # Skip rows that cannot be compared (NULL hash, or a hash
                # computed with a different hash_size than the query's)
                if not db_avg_hash or len(db_avg_hash) != len(avg_hash):
                    continue

                distance = PerceptualHash.hamming_distance(avg_hash, db_avg_hash)

                if distance <= max_distance:
                    similarity = 1.0 - (distance / len(avg_hash))
                    metadata = json.loads(metadata_json) if metadata_json else {}
                    results.append((item_name, similarity, {
                        'item_id': item_id,
                        'category': category,
                        'metadata': metadata
                    }))
        finally:
            conn.close()

        # Best matches first
        results.sort(key=lambda x: x[1], reverse=True)
        return results

    def get_all_icons(self) -> List[Dict]:
        """Return every stored icon as a dict (features blob excluded)."""
        conn = sqlite3.connect(str(self.db_path))
        try:
            cursor = conn.cursor()
            cursor.execute('''
                SELECT item_name, item_id, category, avg_hash, metadata
                FROM icons
            ''')

            results = []
            for row in cursor.fetchall():
                results.append({
                    'item_name': row[0],
                    'item_id': row[1],
                    'category': row[2],
                    'avg_hash': row[3],
                    'metadata': json.loads(row[4]) if row[4] else {}
                })
        finally:
            conn.close()
        return results

    def get_icon_count(self) -> int:
        """Return the total number of icons in the database."""
        conn = sqlite3.connect(str(self.db_path))
        try:
            cursor = conn.cursor()
            cursor.execute('SELECT COUNT(*) FROM icons')
            count = cursor.fetchone()[0]
        finally:
            conn.close()
        return count

    def delete_icon(self, item_name: str) -> bool:
        """Delete all rows named *item_name*; True if anything was deleted."""
        conn = sqlite3.connect(str(self.db_path))
        try:
            cursor = conn.cursor()
            cursor.execute('DELETE FROM icons WHERE item_name = ?', (item_name,))
            conn.commit()
            deleted = cursor.rowcount > 0
        finally:
            conn.close()
        return deleted
|
|
|
|
|
|
class IconMatcher:
|
|
"""
|
|
Main icon matching interface.
|
|
Combines multiple matching algorithms for best results.
|
|
"""
|
|
|
|
# Confidence thresholds
|
|
CONFIDENCE_HIGH = 0.85
|
|
CONFIDENCE_MEDIUM = 0.70
|
|
CONFIDENCE_LOW = 0.50
|
|
|
|
def __init__(self, database_path: Optional[Path] = None,
|
|
icons_dir: Optional[Path] = None):
|
|
"""
|
|
Initialize icon matcher.
|
|
|
|
Args:
|
|
database_path: Path to icon database
|
|
icons_dir: Directory containing icon images for matching
|
|
"""
|
|
self.database = IconDatabase(database_path)
|
|
self.icons_dir = icons_dir or Path.home() / ".lemontropia" / "icons"
|
|
self.feature_matcher = FeatureMatcher()
|
|
|
|
# Cache for loaded icons
|
|
self._icon_cache: Dict[str, np.ndarray] = {}
|
|
|
|
def match_icon(self, image: np.ndarray,
|
|
match_methods: List[str] = None) -> Optional[MatchResult]:
|
|
"""
|
|
Match an icon image against the database.
|
|
|
|
Args:
|
|
image: Icon image (numpy array)
|
|
match_methods: List of methods to use ('hash', 'feature', 'template')
|
|
|
|
Returns:
|
|
MatchResult if match found, None otherwise
|
|
"""
|
|
if match_methods is None:
|
|
match_methods = ['hash', 'feature', 'template']
|
|
|
|
results = []
|
|
|
|
# Method 1: Perceptual Hash Matching
|
|
if 'hash' in match_methods:
|
|
hash_result = self._match_by_hash(image)
|
|
if hash_result:
|
|
results.append(hash_result)
|
|
|
|
# Method 2: Feature Matching
|
|
if 'feature' in match_methods:
|
|
feature_result = self._match_by_features(image)
|
|
if feature_result:
|
|
results.append(feature_result)
|
|
|
|
# Method 3: Template Matching
|
|
if 'template' in match_methods:
|
|
template_result = self._match_by_template(image)
|
|
if template_result:
|
|
results.append(template_result)
|
|
|
|
if not results:
|
|
return None
|
|
|
|
# Return best match
|
|
best = max(results, key=lambda x: x.confidence)
|
|
return best
|
|
|
|
def _match_by_hash(self, image: np.ndarray) -> Optional[MatchResult]:
|
|
"""Match using perceptual hashing."""
|
|
avg_hash = PerceptualHash.average_hash(image)
|
|
|
|
# Query database
|
|
matches = self.database.find_by_hash(avg_hash, max_distance=15)
|
|
|
|
if not matches:
|
|
return None
|
|
|
|
best_match = matches[0]
|
|
item_name, similarity, meta = best_match
|
|
|
|
if similarity >= self.CONFIDENCE_LOW:
|
|
return MatchResult(
|
|
item_name=item_name,
|
|
confidence=similarity,
|
|
match_method='hash',
|
|
item_id=meta.get('item_id'),
|
|
category=meta.get('category'),
|
|
metadata=meta.get('metadata', {})
|
|
)
|
|
|
|
return None
|
|
|
|
def _match_by_features(self, image: np.ndarray) -> Optional[MatchResult]:
|
|
"""Match using ORB features."""
|
|
_, query_desc = self.feature_matcher.extract_features(image)
|
|
|
|
if query_desc is None:
|
|
return None
|
|
|
|
# Get all icons with features from database
|
|
conn = sqlite3.connect(str(self.database.db_path))
|
|
cursor = conn.cursor()
|
|
cursor.execute('''
|
|
SELECT item_name, features, item_id, category, metadata
|
|
FROM icons WHERE features IS NOT NULL
|
|
''')
|
|
|
|
best_match = None
|
|
best_score = 0.0
|
|
best_meta = {}
|
|
|
|
for row in cursor.fetchall():
|
|
item_name, features_blob, item_id, category, metadata_json = row
|
|
db_desc = pickle.loads(features_blob)
|
|
|
|
score = self.feature_matcher.match_features(query_desc, db_desc)
|
|
|
|
if score > best_score:
|
|
best_score = score
|
|
best_match = item_name
|
|
best_meta = {
|
|
'item_id': item_id,
|
|
'category': category,
|
|
'metadata': json.loads(metadata_json) if metadata_json else {}
|
|
}
|
|
|
|
conn.close()
|
|
|
|
if best_match and best_score >= self.CONFIDENCE_LOW:
|
|
return MatchResult(
|
|
item_name=best_match,
|
|
confidence=best_score,
|
|
match_method='feature',
|
|
item_id=best_meta.get('item_id'),
|
|
category=best_meta.get('category'),
|
|
metadata=best_meta.get('metadata', {})
|
|
)
|
|
|
|
return None
|
|
|
|
def _match_by_template(self, image: np.ndarray) -> Optional[MatchResult]:
|
|
"""Match using template matching against icon files."""
|
|
if not self.icons_dir.exists():
|
|
return None
|
|
|
|
# Resize query to standard size
|
|
standard_size = (64, 64)
|
|
query_resized = cv2.resize(image, standard_size, interpolation=cv2.INTER_AREA)
|
|
|
|
best_match = None
|
|
best_score = 0.0
|
|
|
|
for icon_file in self.icons_dir.glob("**/*.png"):
|
|
try:
|
|
template = cv2.imread(str(icon_file), cv2.IMREAD_COLOR)
|
|
if template is None:
|
|
continue
|
|
|
|
template_resized = cv2.resize(template, standard_size, interpolation=cv2.INTER_AREA)
|
|
|
|
score = TemplateMatcher.match(query_resized, template_resized)
|
|
|
|
if score > best_score:
|
|
best_score = score
|
|
best_match = icon_file.stem
|
|
|
|
except Exception as e:
|
|
logger.debug(f"Template matching failed for {icon_file}: {e}")
|
|
continue
|
|
|
|
if best_match and best_score >= self.CONFIDENCE_MEDIUM:
|
|
return MatchResult(
|
|
item_name=best_match,
|
|
confidence=best_score,
|
|
match_method='template'
|
|
)
|
|
|
|
return None
|
|
|
|
def add_icon_to_database(self, item_name: str, image: np.ndarray,
|
|
item_id: Optional[str] = None,
|
|
category: Optional[str] = None,
|
|
metadata: Optional[Dict] = None) -> bool:
|
|
"""Add a new icon to the database."""
|
|
return self.database.add_icon(item_name, image, item_id, category, metadata)
|
|
|
|
def batch_add_icons(self, icons_dir: Path,
|
|
category: Optional[str] = None) -> Tuple[int, int]:
|
|
"""
|
|
Batch add icons from directory.
|
|
|
|
Returns:
|
|
Tuple of (success_count, fail_count)
|
|
"""
|
|
success = 0
|
|
failed = 0
|
|
|
|
for icon_file in icons_dir.glob("**/*.png"):
|
|
try:
|
|
image = cv2.imread(str(icon_file), cv2.IMREAD_COLOR)
|
|
if image is None:
|
|
failed += 1
|
|
continue
|
|
|
|
item_name = icon_file.stem.replace('_', ' ').title()
|
|
|
|
if self.add_icon_to_database(item_name, image, category=category):
|
|
success += 1
|
|
else:
|
|
failed += 1
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to add icon {icon_file}: {e}")
|
|
failed += 1
|
|
|
|
logger.info(f"Batch add complete: {success} success, {failed} failed")
|
|
return success, failed
|
|
|
|
def get_database_stats(self) -> Dict[str, Any]:
|
|
"""Get database statistics."""
|
|
return {
|
|
'total_icons': self.database.get_icon_count(),
|
|
'database_path': str(self.database.db_path),
|
|
'icons_directory': str(self.icons_dir)
|
|
}
|
|
|
|
def find_similar_icons(self, image: np.ndarray,
|
|
top_k: int = 5) -> List[MatchResult]:
|
|
"""Find top-k similar icons."""
|
|
avg_hash = PerceptualHash.average_hash(image)
|
|
|
|
# Get all matches
|
|
matches = self.database.find_by_hash(avg_hash, max_distance=20)
|
|
|
|
results = []
|
|
for item_name, similarity, meta in matches[:top_k]:
|
|
results.append(MatchResult(
|
|
item_name=item_name,
|
|
confidence=similarity,
|
|
match_method='hash',
|
|
item_id=meta.get('item_id'),
|
|
category=meta.get('category'),
|
|
metadata=meta.get('metadata', {})
|
|
))
|
|
|
|
return results
|
|
|
|
|
|
# Public API of this module: the classes exported by `from ... import *`
# and the names downstream code should rely on.
__all__ = [
    'IconMatcher',
    'MatchResult',
    'PerceptualHash',
    'FeatureMatcher',
    'TemplateMatcher',
    'IconDatabase'
]
|