# Lemontropia-Suite/modules/icon_matcher.py

"""
Lemontropia Suite - Icon Matcher Module
Icon similarity matching using multiple algorithms.
Supports perceptual hashing, template matching, and feature-based matching.
"""
import json
import logging
import pickle
import sqlite3
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import cv2
import numpy as np
logger = logging.getLogger(__name__)
@dataclass
class MatchResult:
    """Result of an icon match attempt.

    Attributes:
        item_name: Display name of the matched item.
        confidence: Match confidence in [0, 1].
        match_method: Algorithm that produced the match ('hash', 'feature', 'template').
        item_id: Optional stable identifier of the item.
        category: Optional item category.
        metadata: Free-form extra data; guaranteed to be a dict (never None)
            after construction.
    """
    item_name: str
    confidence: float
    match_method: str
    item_id: Optional[str] = None
    category: Optional[str] = None
    # default_factory gives each instance its own dict (the annotation said
    # Dict but the default was None); __post_init__ still normalizes an
    # explicitly passed metadata=None for backward compatibility.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        if self.metadata is None:
            self.metadata = {}
class PerceptualHash:
    """Perceptual hash implementations (aHash, dHash, wHash) for icon matching.

    All hashes are returned as '0'/'1' bit strings. average_hash and
    wavelet_hash both produce hash_size**2 bits so the two stay comparable
    via hamming_distance regardless of which algorithm produced the stored
    hash (important because wavelet_hash falls back to average_hash).
    """

    @staticmethod
    def _to_gray(image: np.ndarray) -> np.ndarray:
        """Collapse a 3-channel BGR image to grayscale; pass grayscale through."""
        if len(image.shape) == 3:
            return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        return image

    @staticmethod
    def average_hash(image: np.ndarray, hash_size: int = 16) -> str:
        """Compute average hash (aHash): one bit per pixel, set when > mean."""
        gray = PerceptualHash._to_gray(image)
        resized = cv2.resize(gray, (hash_size, hash_size), interpolation=cv2.INTER_AREA)
        avg = resized.mean()
        hash_bits = (resized > avg).flatten()
        return ''.join(['1' if b else '0' for b in hash_bits])

    @staticmethod
    def difference_hash(image: np.ndarray, hash_size: int = 16) -> str:
        """Compute difference hash (dHash): one bit per horizontal neighbor pair."""
        gray = PerceptualHash._to_gray(image)
        # hash_size+1 columns yield exactly hash_size differences per row.
        resized = cv2.resize(gray, (hash_size + 1, hash_size), interpolation=cv2.INTER_AREA)
        diff = resized[:, 1:] > resized[:, :-1]
        return ''.join(['1' if b else '0' for b in diff.flatten()])

    @staticmethod
    def wavelet_hash(image: np.ndarray, hash_size: int = 16) -> str:
        """Compute wavelet hash (wHash) from Haar approximation coefficients.

        Falls back to average_hash when PyWavelets is unavailable. Both
        branches return hash_size**2 bits, so stored wHashes remain
        comparable no matter which branch produced them.
        """
        try:
            import pywt
        except ImportError:
            logger.debug("PyWavelets not available, falling back to average hash")
            return PerceptualHash.average_hash(image, hash_size)
        gray = PerceptualHash._to_gray(image)
        # Resize to the next power of two >= 2*hash_size: one dwt2 level
        # halves each dimension, so cA comes out at least hash_size square.
        # (The original resized to roughly hash_size itself, so cA was only
        # (hash_size/2)**2 bits -- a length hamming_distance rejects against
        # the other hashes and against the fallback branch above.)
        size = 2 ** (2 * hash_size - 1).bit_length()
        resized = cv2.resize(gray, (size, size), interpolation=cv2.INTER_AREA)
        # Apply one level of the Haar wavelet transform.
        cA, (cH, cV, cD) = pywt.dwt2(resized, 'haar')
        # Crop to exactly hash_size x hash_size (cA may be larger when
        # hash_size is not itself a power of two).
        cA = cA[:hash_size, :hash_size]
        avg = cA.mean()
        hash_bits = (cA > avg).flatten()
        return ''.join(['1' if b else '0' for b in hash_bits])

    @staticmethod
    def hamming_distance(hash1: str, hash2: str) -> int:
        """Count differing bit positions.

        Raises:
            ValueError: if the hashes have different lengths.
        """
        if len(hash1) != len(hash2):
            raise ValueError("Hashes must be same length")
        return sum(c1 != c2 for c1, c2 in zip(hash1, hash2))

    @staticmethod
    def similarity(hash1: str, hash2: str) -> float:
        """Return similarity in [0, 1]: 1.0 for identical hashes."""
        distance = PerceptualHash.hamming_distance(hash1, hash2)
        max_distance = len(hash1)
        return 1.0 - (distance / max_distance)
class FeatureMatcher:
    """ORB-feature based icon matching with brute-force Hamming matching."""

    def __init__(self):
        # 500 keypoints is plenty for small icon crops; crossCheck keeps only
        # mutually-best descriptor pairs.
        self.orb = cv2.ORB_create(nfeatures=500)
        self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)

    def extract_features(self, image: np.ndarray) -> Tuple[List, np.ndarray]:
        """Return (keypoints, descriptors) for a BGR or grayscale image."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
        return self.orb.detectAndCompute(gray, None)

    def match_features(self, desc1: np.ndarray, desc2: np.ndarray,
                       threshold: float = 0.7) -> float:
        """Score descriptor-set similarity as a confidence in [0, 1].

        Args:
            desc1: Query descriptors (may be None).
            desc2: Candidate descriptors (may be None).
            threshold: Kept for interface compatibility; the current scoring
                uses a fixed Hamming-distance cutoff instead.

        Returns:
            0.0 when either side is missing, too few pairs match, or matching
            fails; otherwise the fraction of strong matches capped at 1.0.
        """
        if desc1 is None or desc2 is None:
            return 0.0
        try:
            pairs = sorted(self.matcher.match(desc1, desc2),
                           key=lambda m: m.distance)
            # Fewer than 4 total matches is too weak to trust at all.
            if len(pairs) < 4:
                return 0.0
            # "Strong" pairs sit below a fixed Hamming distance of 50.
            strong = [m for m in pairs if m.distance < 50]
            if not strong:
                return 0.0
            # Saturate at 20 strong matches -> full confidence.
            return min(len(strong) / 20, 1.0)
        except Exception as e:
            logger.debug(f"Feature matching failed: {e}")
            return 0.0
class TemplateMatcher:
    """Template matching for icons."""

    @staticmethod
    def match(template: np.ndarray, image: np.ndarray,
              methods: List[int] = None) -> float:
        """Match *template* against *image* using several cv2 methods.

        The image is resized to the template's size first, which collapses
        each result map to a single score. Scores are normalized so higher is
        always better, and the best score across all methods is returned.

        Args:
            template: Reference icon image.
            image: Candidate image (resized to the template's size).
            methods: cv2 matchTemplate method constants; defaults to
                CCOEFF/CCORR/SQDIFF normalized variants.

        Returns:
            Best confidence score in [0, 1] (0.0 if every method failed).
        """
        if methods is None:
            methods = [
                cv2.TM_CCOEFF_NORMED,
                cv2.TM_CCORR_NORMED,
                cv2.TM_SQDIFF_NORMED
            ]
        # Force equal sizes so matchTemplate yields a single value.
        h, w = template.shape[:2]
        image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
        best_score = 0.0
        for method in methods:
            try:
                result = cv2.matchTemplate(image, template, method)
                min_val, max_val, _, _ = cv2.minMaxLoc(result)
                # SQDIFF is a dissimilarity: its best match is the *minimum*,
                # so invert min_val. (The original inverted max_val, which was
                # only equivalent because the 1x1 result makes min == max.)
                if method == cv2.TM_SQDIFF_NORMED:
                    score = 1.0 - min_val
                else:
                    score = max_val
                best_score = max(best_score, score)
            except Exception as e:
                logger.debug(f"Template matching failed: {e}")
                continue
        return best_score
class IconDatabase:
    """SQLite-backed store mapping icons to perceptual hashes and ORB features."""

    def __init__(self, db_path: Optional[Path] = None):
        """Open the database, creating the file and schema if needed.

        Args:
            db_path: Database file location; defaults to
                ~/.lemontropia/icon_database.db.
        """
        self.db_path = db_path or Path.home() / ".lemontropia" / "icon_database.db"
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_database()

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection to the backing database file."""
        return sqlite3.connect(str(self.db_path))

    def _init_database(self):
        """Create the icons table and its indexes if they do not exist."""
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS icons (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    item_name TEXT NOT NULL,
                    item_id TEXT,
                    category TEXT,
                    avg_hash TEXT,
                    diff_hash TEXT,
                    wavelet_hash TEXT,
                    features BLOB,
                    metadata TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')
            cursor.execute('''
                CREATE INDEX IF NOT EXISTS idx_avg_hash ON icons(avg_hash)
            ''')
            cursor.execute('''
                CREATE INDEX IF NOT EXISTS idx_item_name ON icons(item_name)
            ''')
            conn.commit()
        finally:
            # Always release the handle; the original leaked it on any error.
            conn.close()

    def add_icon(self, item_name: str, image: np.ndarray,
                 item_id: Optional[str] = None,
                 category: Optional[str] = None,
                 metadata: Optional[Dict] = None) -> bool:
        """Hash *image*, extract its ORB features, and insert one row.

        Returns:
            True on success, False if hashing or the insert failed (the
            error is logged, matching the original best-effort contract).
        """
        try:
            # Compute all three perceptual hashes up front.
            avg_hash = PerceptualHash.average_hash(image)
            diff_hash = PerceptualHash.difference_hash(image)
            wavelet_hash = PerceptualHash.wavelet_hash(image)
            # ORB descriptors may legitimately be None for featureless icons.
            feature_matcher = FeatureMatcher()
            _, features = feature_matcher.extract_features(image)
            # NOTE: pickle is acceptable here only because the blob is written
            # and read locally; never load a database file from an untrusted
            # source through this class.
            features_blob = pickle.dumps(features) if features is not None else None
            conn = self._connect()
            try:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO icons
                    (item_name, item_id, category, avg_hash, diff_hash, wavelet_hash, features, metadata)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    item_name, item_id, category,
                    avg_hash, diff_hash, wavelet_hash,
                    features_blob,
                    json.dumps(metadata) if metadata else None
                ))
                conn.commit()
            finally:
                conn.close()
            logger.debug(f"Added icon to database: {item_name}")
            return True
        except Exception as e:
            logger.error(f"Failed to add icon: {e}")
            return False

    def find_by_hash(self, avg_hash: str, max_distance: int = 10) -> List[Tuple[str, float, Dict]]:
        """Find icons whose average hash is within *max_distance* bits.

        Args:
            avg_hash: Query aHash bit string.
            max_distance: Maximum Hamming distance to accept.

        Returns:
            (item_name, similarity, info) tuples sorted by descending
            similarity, where info holds item_id/category/metadata.
        """
        results = []
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute('SELECT item_name, avg_hash, diff_hash, item_id, category, metadata FROM icons')
            for row in cursor.fetchall():
                item_name, db_avg_hash, db_diff_hash, item_id, category, metadata_json = row
                # Rows hashed with a different hash_size are not comparable;
                # the original let hamming_distance raise ValueError here
                # (and leaked the connection). Skip them instead.
                if not db_avg_hash or len(db_avg_hash) != len(avg_hash):
                    continue
                distance = PerceptualHash.hamming_distance(avg_hash, db_avg_hash)
                if distance <= max_distance:
                    similarity = 1.0 - (distance / len(avg_hash))
                    metadata = json.loads(metadata_json) if metadata_json else {}
                    results.append((item_name, similarity, {
                        'item_id': item_id,
                        'category': category,
                        'metadata': metadata
                    }))
        finally:
            conn.close()
        # Best match first.
        results.sort(key=lambda x: x[1], reverse=True)
        return results

    def get_all_icons(self) -> List[Dict]:
        """Return every icon row as a plain dict (feature blobs excluded)."""
        results = []
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute('''
                SELECT item_name, item_id, category, avg_hash, metadata
                FROM icons
            ''')
            for row in cursor.fetchall():
                results.append({
                    'item_name': row[0],
                    'item_id': row[1],
                    'category': row[2],
                    'avg_hash': row[3],
                    'metadata': json.loads(row[4]) if row[4] else {}
                })
        finally:
            conn.close()
        return results

    def get_icon_count(self) -> int:
        """Return the total number of icons stored."""
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute('SELECT COUNT(*) FROM icons')
            return cursor.fetchone()[0]
        finally:
            conn.close()

    def delete_icon(self, item_name: str) -> bool:
        """Delete all rows named *item_name*; True if anything was removed."""
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute('DELETE FROM icons WHERE item_name = ?', (item_name,))
            conn.commit()
            return cursor.rowcount > 0
        finally:
            conn.close()
class IconMatcher:
    """
    Main icon matching interface.
    Combines perceptual hashing, ORB feature matching and file-based template
    matching, returning the single highest-confidence result.
    """
    # Confidence thresholds (fractions in [0, 1]).
    CONFIDENCE_HIGH = 0.85
    CONFIDENCE_MEDIUM = 0.70
    CONFIDENCE_LOW = 0.50

    def __init__(self, database_path: Optional[Path] = None,
                 icons_dir: Optional[Path] = None):
        """
        Initialize icon matcher.
        Args:
            database_path: Path to icon database
            icons_dir: Directory containing icon images for matching
        """
        self.database = IconDatabase(database_path)
        self.icons_dir = icons_dir or Path.home() / ".lemontropia" / "icons"
        self.feature_matcher = FeatureMatcher()
        # Cache of pre-resized template images keyed by file path, filled
        # lazily by _match_by_template. Entries go stale if the file changes
        # on disk during the matcher's lifetime.
        self._icon_cache: Dict[str, np.ndarray] = {}

    def match_icon(self, image: np.ndarray,
                   match_methods: List[str] = None) -> Optional[MatchResult]:
        """
        Match an icon image against the database.
        Args:
            image: Icon image (numpy array)
            match_methods: List of methods to use ('hash', 'feature', 'template')
        Returns:
            MatchResult if match found, None otherwise
        """
        if match_methods is None:
            match_methods = ['hash', 'feature', 'template']
        results = []
        # Method 1: perceptual hash lookup against the database.
        if 'hash' in match_methods:
            hash_result = self._match_by_hash(image)
            if hash_result:
                results.append(hash_result)
        # Method 2: ORB feature matching against stored descriptors.
        if 'feature' in match_methods:
            feature_result = self._match_by_features(image)
            if feature_result:
                results.append(feature_result)
        # Method 3: template matching against icon files on disk.
        if 'template' in match_methods:
            template_result = self._match_by_template(image)
            if template_result:
                results.append(template_result)
        if not results:
            return None
        # Whichever method reported the highest confidence wins.
        return max(results, key=lambda x: x.confidence)

    def _match_by_hash(self, image: np.ndarray) -> Optional[MatchResult]:
        """Match using perceptual hashing; None below CONFIDENCE_LOW."""
        avg_hash = PerceptualHash.average_hash(image)
        # A wider distance here than the DB default; results are ranked anyway.
        matches = self.database.find_by_hash(avg_hash, max_distance=15)
        if not matches:
            return None
        item_name, similarity, meta = matches[0]
        if similarity >= self.CONFIDENCE_LOW:
            return MatchResult(
                item_name=item_name,
                confidence=similarity,
                match_method='hash',
                item_id=meta.get('item_id'),
                category=meta.get('category'),
                metadata=meta.get('metadata', {})
            )
        return None

    def _match_by_features(self, image: np.ndarray) -> Optional[MatchResult]:
        """Match using ORB features stored in the database."""
        _, query_desc = self.feature_matcher.extract_features(image)
        if query_desc is None:
            return None
        best_match = None
        best_score = 0.0
        best_meta: Dict[str, Any] = {}
        conn = sqlite3.connect(str(self.database.db_path))
        try:
            cursor = conn.cursor()
            cursor.execute('''
                SELECT item_name, features, item_id, category, metadata
                FROM icons WHERE features IS NOT NULL
            ''')
            for row in cursor.fetchall():
                item_name, features_blob, item_id, category, metadata_json = row
                try:
                    # Blobs are written locally by IconDatabase.add_icon;
                    # pickle on an untrusted database file would be unsafe.
                    db_desc = pickle.loads(features_blob)
                except Exception as e:
                    # One corrupt blob should not abort the whole scan.
                    logger.debug(f"Skipping corrupt feature blob for {item_name}: {e}")
                    continue
                score = self.feature_matcher.match_features(query_desc, db_desc)
                if score > best_score:
                    best_score = score
                    best_match = item_name
                    best_meta = {
                        'item_id': item_id,
                        'category': category,
                        'metadata': json.loads(metadata_json) if metadata_json else {}
                    }
        finally:
            # Close even when a query fails; the original leaked the handle.
            conn.close()
        if best_match and best_score >= self.CONFIDENCE_LOW:
            return MatchResult(
                item_name=best_match,
                confidence=best_score,
                match_method='feature',
                item_id=best_meta.get('item_id'),
                category=best_meta.get('category'),
                metadata=best_meta.get('metadata', {})
            )
        return None

    def _match_by_template(self, image: np.ndarray) -> Optional[MatchResult]:
        """Match using template matching against icon files on disk."""
        if not self.icons_dir.exists():
            return None
        # Resize query to the standard comparison size.
        standard_size = (64, 64)
        query_resized = cv2.resize(image, standard_size, interpolation=cv2.INTER_AREA)
        best_match = None
        best_score = 0.0
        for icon_file in self.icons_dir.glob("**/*.png"):
            try:
                key = str(icon_file)
                template_resized = self._icon_cache.get(key)
                if template_resized is None:
                    template = cv2.imread(key, cv2.IMREAD_COLOR)
                    if template is None:
                        continue
                    template_resized = cv2.resize(template, standard_size,
                                                  interpolation=cv2.INTER_AREA)
                    # Cache the resized template so repeated match calls do
                    # not re-read and re-decode every icon file.
                    self._icon_cache[key] = template_resized
                score = TemplateMatcher.match(query_resized, template_resized)
                if score > best_score:
                    best_score = score
                    best_match = icon_file.stem
            except Exception as e:
                logger.debug(f"Template matching failed for {icon_file}: {e}")
                continue
        if best_match and best_score >= self.CONFIDENCE_MEDIUM:
            return MatchResult(
                item_name=best_match,
                confidence=best_score,
                match_method='template'
            )
        return None

    def add_icon_to_database(self, item_name: str, image: np.ndarray,
                             item_id: Optional[str] = None,
                             category: Optional[str] = None,
                             metadata: Optional[Dict] = None) -> bool:
        """Add a new icon to the database. Returns True on success."""
        return self.database.add_icon(item_name, image, item_id, category, metadata)

    def batch_add_icons(self, icons_dir: Path,
                        category: Optional[str] = None) -> Tuple[int, int]:
        """
        Batch add icons from directory (recursively, *.png only).
        Item names are derived from file stems ('fire_sword' -> 'Fire Sword').
        Returns:
            Tuple of (success_count, fail_count)
        """
        success = 0
        failed = 0
        for icon_file in icons_dir.glob("**/*.png"):
            try:
                image = cv2.imread(str(icon_file), cv2.IMREAD_COLOR)
                if image is None:
                    failed += 1
                    continue
                item_name = icon_file.stem.replace('_', ' ').title()
                if self.add_icon_to_database(item_name, image, category=category):
                    success += 1
                else:
                    failed += 1
            except Exception as e:
                logger.error(f"Failed to add icon {icon_file}: {e}")
                failed += 1
        logger.info(f"Batch add complete: {success} success, {failed} failed")
        return success, failed

    def get_database_stats(self) -> Dict[str, Any]:
        """Get database statistics (icon count and storage locations)."""
        return {
            'total_icons': self.database.get_icon_count(),
            'database_path': str(self.database.db_path),
            'icons_directory': str(self.icons_dir)
        }

    def find_similar_icons(self, image: np.ndarray,
                           top_k: int = 5) -> List[MatchResult]:
        """Find up to top_k similar icons, ranked by hash similarity."""
        avg_hash = PerceptualHash.average_hash(image)
        # Wider distance than single-match lookups: this is a ranked list.
        matches = self.database.find_by_hash(avg_hash, max_distance=20)
        results = []
        for item_name, similarity, meta in matches[:top_k]:
            results.append(MatchResult(
                item_name=item_name,
                confidence=similarity,
                match_method='hash',
                item_id=meta.get('item_id'),
                category=meta.get('category'),
                metadata=meta.get('metadata', {})
            ))
        return results
# Public API of this module: the combined matcher, its result type, and the
# individual algorithm/storage classes for direct use.
__all__ = [
    'IconMatcher',
    'MatchResult',
    'PerceptualHash',
    'FeatureMatcher',
    'TemplateMatcher',
    'IconDatabase'
]