1.背景及目标
随着互联网技术的发展,图像数据呈指数级增长。图像搜索技术已经成为人们日常生活和工作中不可或缺的一部分,尤其是在电子商务、社交媒体、在线教育等领域。传统的基于文本的搜索引擎虽然已经非常成熟,但在面对大量无标签或标签不准确的图像时,其效果往往不尽如人意。相比之下,以图搜图(Image-to-Image Search, ITS)技术则能够直接利用图像本身的视觉特征进行检索,从而提高搜索的准确性和效率。
本项目旨在实现一个简单的以图搜图应用实例,具体目标是输入一张狗的图片,通过匹配数据库中最相似的三张图片,进而推测输入图片中的狗属于哪一类。
第一个是原图,后面三张是匹配的图数字表示相似度分数。这个小例子没有进行优化,较为简单,不涉及训练,准确率远远不如图像分类,主要用来熟悉流程以图搜图的流程和实现。
2.实现步骤
主要步骤
1. 数据准备
【数据集】11种犬类,共1089张
链接:https://pan.baidu.com/s/1sjlghIz_CXjXdLAkn030tQ
提取码:qlrt
2.图像预处理
对图像进行标准化处理,如缩放、裁剪、灰度化等,以保证图像的一致性。【本文无】
3. 特征提取
【网络模型】采用预训练resnet18进行特征提取,权重链接:
链接:https://pan.baidu.com/s/1TXZt6eo9F7lhHJP3QhFwQw
提取码:ou81
4.特征编码与存储
将提取到的特征点转换为紧凑的特征向量表示形式,并创建构建索引结构。【本文直接使用简单使用 h5py】
5.图像查询
特征提取:对查询图像应用与步骤3中相同的特征提取方法,获取其特征向量。
相似度计算:【本文简单使用向量积来计算】
scores = np.dot(queryVec, feats.T) 用来计算查询图像的特征向量(queryVec)与数据库中所有图像的特征向量(feats)之间的相似度得分
排序筛选:根据相似度得分对数据库中的图像进行排序,并选取最相似的若干图像。
6.结果展示:
结果显示:将最相似的图像进行展示,并可能包括额外的信息,如图像的类别、来源等。
3.具体代码实现
1.网络加载与特征提取
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec
@File :resnet18.py
@IDE :PyCharm
@Author :菜菜2024
@Date :2024/9/30
'''
from PIL import Image
from torchvision import transforms
import torch
import torch.nn as nn
from torchvision import models
class ResNet18:
def __init__(self, model_path='E:\\xxxx\\weights\\resnet18.pth'):
self.trans = transforms.Compose([
transforms.Resize(size=(256, 256)),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
print("-----------loading resnet18------------")
self.model = models.resnet18()
self.model.load_state_dict(torch.load(model_path))
num_feats = self.model.fc.in_features
self.model.fc = nn.Linear(num_feats, 128)
self.model.eval()
def extract_image_features(self, img_path):
image = Image.open(img_path).convert('RGB')
image_tensor = self.trans(image).unsqueeze(0)
with torch.no_grad():
features = self.model(image_tensor)
return features
2.图片向量化及保存
提取数据库图片特征,保存为向量pet_dog.h5文件
# -*- coding: utf-8 -*-
'''
@Project :ImageRec
@File :saveFeature.py
@IDE :PyCharm
@Author :菜菜2024
@Date :2024/9/30
'''
import os
import h5py
import numpy as np
from tqdm import tqdm
from model import ResNet18
ALLOW_EXT = {'png', 'jpg', 'jpeg'}
def getImagVectors(image_files, index_path):
'''
批量提取特征并保存为h5文件
:param image_files: 图片数据库路径
:param index_path: h5保存路径
:return:
'''
feats = []
image_ids = []
img_list = []
subdirs = os.listdir(image_files)
resnet18 = ResNet18()
print("*********开始进行特征提取*********")
for i in range(len(subdirs)):
imgs_path = [image_files +"\\" + subdirs[i]+ "\\"+ img_path for img_path in os.listdir(os.path.join(image_files, subdirs[i]))]
img_list.extend(imgs_path)
for ids, image_path in enumerate(tqdm(img_list)):
try:
jd_strings = image_path.split("\\")
image_id = jd_strings[-2] +"_" + jd_strings[-1].split(".")[0]
# 保存成“中华田园犬_0”的格式,为了展示时好取出对应图片
features = resnet18.extract_image_features(image_path).squeeze(0)
feats.append(features.detach().numpy())
image_ids.append(image_id.encode('utf-8'))
except Exception as e:
print(f"处理出错 image:{image_path}, 原因:{e}")
continue
feats = np.array(feats).astype('float32')
h5f = h5py.File(index_path, 'w')
h5f.create_dataset('图像特征', data=feats)
h5f.create_dataset('图像名称', data=np.string_(image_ids))
h5f.close()
print("*********图像向量化保存完成*********")
if __name__ == "__main__":
database = 'E:\xxx\datas\pet_dog'
index_path = './pet_dog.h5'
getImagVectors(database, index_path)
3.图像查询及可视化
对查询图像进行特征提取,加载第二步保存的h5文件,通过向量积的方式得到相似度得分,倒序排序,取对应的图像名称,到图片数据库种读取对应的图片,进行可视化。
# -*- coding: utf-8 -*-
'''
@Project :ImageRec
@File :test.py
@IDE :PyCharm
@Author :菜菜2024
@Date :2024/9/30
'''
import os.path
import numpy as np
import h5py
import matplotlib.image as mpimg
from model import ResNet18
import matplotlib.pyplot as plt
from matplotlib import rc
# 设置全局字体为支持中文的字体
rc('font', family='SimHei') # 黑体
def matchFeatureVector(indexfile, query):
'''
加载图片特征向量库,并对输入的查询图片进行向量化
:param indexfile:保存的特征h5文件路径
:param query:查询图片路径
:return:
'''
# 1.加载图片向量库
h5f = h5py.File(indexfile, 'r')
feats = h5f['图像特征'][:]
imgNames = h5f['图像名称'][:]
h5f.close()
#2.查询图片向量化
resnet18 = ResNet18()
queryVec = resnet18.extract_image_features(query).squeeze(0)
scores = np.dot(queryVec, feats.T)
rank_ID = np.argsort(scores)[::-1]
rank_score = scores[rank_ID]
return rank_ID, rank_score, imgNames
def showRes(indexfile, query, result, maxres):
print("*********查询图片向量*********")
rank_ID, rank_score, imgNames = matchFeatureVector(indexfile, query)
matchlist = []
imgs = []
info = []
queryImg = mpimg.imread(query)
imgs.append(queryImg)
info.append(query.split("/")[-1])
for i, index in enumerate(rank_ID[0:maxres]):
orgStr =str(imgNames[index], 'utf-8').split("_")
orgImg = orgStr[0]+"\\"+orgStr[-1]+".jpg"
imgs.append(mpimg.imread(os.path.join(result, orgImg)))
matchlist.append(orgImg)
info.append(orgImg + '_' + str(rank_score[i]))
print("图片名称是: " + orgImg + " 对应相似度得分是: %f" % rank_score[i])
# print("top %d 图片如下: " % maxres, matchlist)
num = int((maxres+1) // 2)
fig, axs = plt.subplots(nrows=num, ncols=num, figsize=(10, 10))
# 确保即使只有一个子图,也可以进行索引
if not isinstance(axs, np.ndarray):
axs = np.array([[axs]])
# 显示图像
flat_index = 0
for i in range(num):
for j in range(num):
if flat_index < len(imgs):
img = imgs[flat_index]
axs[i, j].imshow(img, cmap='gray')
axs[i, j].axis('off')
axs[i, j].set_title(info[flat_index])
flat_index += 1
else:
axs[i, j].set_visible(False)
plt.tight_layout()
plt.show()
if __name__ == "__main__":
indexfile = './pet_dog.h5'
query = './data/pic/中华田园犬.jpg'
result = 'E:\xxx\datas\pet_dog'
showRes(indexfile, query, result, maxres=3)
未来方向
在以图搜图的应用中,遇到“相同的图片和程序允许两次,得到的结果还不一样”的问题,以及“准确率很低”的情况,通常表明特征提取和相似度计算存在缺陷。以下是针对这些问题的具体分析和优化方案:
问题分析
- 特征提取不具代表性:
使用预训练的 ResNet18 网络可能没有针对犬类这一特定领域的特征进行优化。 - 特征表示单一:
当前的特征提取方法仅从单张图片中提取特征向量,未能有效地概括类别特征。 - 相似度衡量方法不足:
当前使用的相似度衡量方法(如点积或余弦相似度)可能不够精确或不适合犬类图片的特征比较。
优化方向
- 提高特征表达能力
重新训练 ResNet18:
使用包含11个犬类类别的数据集重新训练 ResNet18 模型,使其更适应特定领域。
使用数据增强技术(如旋转、缩放、翻转等)来增加训练样本的多样性,从而提高模型的泛化能力。 - 使用类向量表示
类向量表示:
考虑从类别角度出发,提取类向量(class vector)。这可以通过对同一类别的多个样本进行平均或聚合操作来实现。使用聚类算法(如 K-Means)对同类图片的特征向量进行聚合,得到类别中心点,然后用类别中心点作为类向量进行相似度计算。 - 优化相似度衡量方法
引入 FAISS(Facebook AI Similarity Search)库来优化特征向量的存储和检索过程。FAISS 提供了高效的近似最近邻搜索算法,可以显著提高检索速度和准确性。