【以图搜图代码实现】--犬类以图搜图示例

news/2024/9/30 12:32:43 标签: python, 图搜索, 多分类, 计算机视觉

1.背景及目标

随着互联网技术的发展,图像数据呈指数级增长。图像搜索技术已经成为人们日常生活和工作中不可或缺的一部分,尤其是在电子商务、社交媒体、在线教育等领域。传统的基于文本的搜索引擎虽然已经非常成熟,但在面对大量无标签或标签不准确的图像时,其效果往往不尽如人意。相比之下,以图搜图(Image-to-Image Search, ITS)技术则能够直接利用图像本身的视觉特征进行检索,从而提高搜索的准确性和效率。

本项目旨在实现一个简单的以图搜图应用实例,具体目标是输入一张狗的图片,通过匹配数据库中最相似的三张图片,进而推测输入图片中的狗属于哪一类

在这里插入图片描述
第一个是原图,后面三张是匹配的图数字表示相似度分数。这个小例子没有进行优化,较为简单,不涉及训练,准确率远远不如图像分类,主要用来熟悉流程以图搜图的流程和实现。

2.实现步骤

主要步骤

1. 数据准备
数据集】11种犬类,共1089张
链接:https://pan.baidu.com/s/1sjlghIz_CXjXdLAkn030tQ
提取码:qlrt
在这里插入图片描述

2.图像预处理
对图像进行标准化处理,如缩放、裁剪、灰度化等,以保证图像的一致性。【本文无】

3. 特征提取
网络模型】采用预训练resnet18进行特征提取,权重链接:
链接:https://pan.baidu.com/s/1TXZt6eo9F7lhHJP3QhFwQw
提取码:ou81

4.特征编码与存储
将提取到的特征点转换为紧凑的特征向量表示形式,并创建构建索引结构。【本文直接使用简单使用 h5py】

5.图像查询
特征提取:对查询图像应用与步骤3中相同的特征提取方法,获取其特征向量。
相似度计算:【本文简单使用向量积来计算】

scores = np.dot(queryVec, feats.T) 用来计算查询图像的特征向量(queryVec)与数据库中所有图像的特征向量(feats)之间的相似度得分

排序筛选:根据相似度得分对数据库中的图像进行排序,并选取最相似的若干图像。
6.结果展示:
结果显示:将最相似的图像进行展示,并可能包括额外的信息,如图像的类别、来源等。

3.具体代码实现

1.网络加载与特征提取

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec
@File    :resnet18.py
@IDE     :PyCharm
@Author  :菜菜2024
@Date    :2024/9/30
'''
from PIL import Image
from torchvision import transforms
import torch
import torch.nn as nn
from torchvision import models

class ResNet18:
    def __init__(self, model_path='E:\\xxxx\\weights\\resnet18.pth'):
        self.trans = transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
        print("-----------loading resnet18------------")
        self.model = models.resnet18()
        self.model.load_state_dict(torch.load(model_path))
        num_feats = self.model.fc.in_features
        self.model.fc = nn.Linear(num_feats, 128)
        self.model.eval()


    def extract_image_features(self, img_path):

        image = Image.open(img_path).convert('RGB')
        image_tensor = self.trans(image).unsqueeze(0)
        with torch.no_grad():
            features = self.model(image_tensor)
        return features

2.图片向量化及保存

提取数据库图片特征,保存为向量pet_dog.h5文件


# -*- coding: utf-8 -*-
'''
@Project :ImageRec
@File    :saveFeature.py
@IDE     :PyCharm
@Author  :菜菜2024
@Date    :2024/9/30
'''

import os
import h5py
import numpy as np
from tqdm import tqdm
from model import ResNet18

ALLOW_EXT = {'png', 'jpg', 'jpeg'}

def getImagVectors(image_files, index_path):
    '''
      批量提取特征并保存为h5文件
      :param image_files: 图片数据库路径
      :param index_path: h5保存路径
      :return:
      '''


    feats = []
    image_ids = []
    img_list = []
    subdirs = os.listdir(image_files)
    resnet18 = ResNet18()

    print("*********开始进行特征提取*********")
    for i in range(len(subdirs)):
        imgs_path = [image_files +"\\" + subdirs[i]+ "\\"+ img_path for img_path in os.listdir(os.path.join(image_files, subdirs[i]))]

        img_list.extend(imgs_path)

    for ids, image_path in enumerate(tqdm(img_list)):
        try:
            jd_strings = image_path.split("\\")

            image_id = jd_strings[-2] +"_" +  jd_strings[-1].split(".")[0]
            # 保存成“中华田园犬_0”的格式,为了展示时好取出对应图片
            features = resnet18.extract_image_features(image_path).squeeze(0)
            feats.append(features.detach().numpy())
            image_ids.append(image_id.encode('utf-8'))
        except Exception as e:
            print(f"处理出错 image:{image_path}, 原因:{e}")
            continue

    feats = np.array(feats).astype('float32')

    h5f = h5py.File(index_path, 'w')
    h5f.create_dataset('图像特征', data=feats)
    h5f.create_dataset('图像名称', data=np.string_(image_ids))
    h5f.close()

    print("*********图像向量化保存完成*********")



if __name__ == "__main__":

    database = 'E:\xxx\datas\pet_dog'
    index_path = './pet_dog.h5'
    getImagVectors(database, index_path)
    

在这里插入图片描述

3.图像查询及可视化

对查询图像进行特征提取,加载第二步保存的h5文件,通过向量积的方式得到相似度得分,倒序排序,取对应的图像名称,到图片数据库种读取对应的图片,进行可视化。

# -*- coding: utf-8 -*-
'''
@Project :ImageRec
@File    :test.py
@IDE     :PyCharm
@Author  :菜菜2024
@Date    :2024/9/30
'''

import os.path
import numpy as np
import h5py
import matplotlib.image as mpimg
from model import ResNet18
import matplotlib.pyplot as plt
from matplotlib import rc

# 设置全局字体为支持中文的字体
rc('font', family='SimHei')  # 黑体
def matchFeatureVector(indexfile, query):
    '''
    加载图片特征向量库,并对输入的查询图片进行向量化
    :param indexfile:保存的特征h5文件路径
    :param query:查询图片路径
    :return:
    '''

    # 1.加载图片向量库
    h5f = h5py.File(indexfile, 'r')
    feats = h5f['图像特征'][:]
    imgNames = h5f['图像名称'][:]
    h5f.close()

    #2.查询图片向量化
    resnet18 = ResNet18()
    queryVec = resnet18.extract_image_features(query).squeeze(0)

    scores = np.dot(queryVec, feats.T)
    rank_ID = np.argsort(scores)[::-1]
    rank_score = scores[rank_ID]

    return rank_ID, rank_score, imgNames


def showRes(indexfile, query, result, maxres):

    print("*********查询图片向量*********")
    rank_ID, rank_score, imgNames = matchFeatureVector(indexfile, query)

    matchlist = []
    imgs = []
    info = []

    queryImg = mpimg.imread(query)
    imgs.append(queryImg)
    info.append(query.split("/")[-1])

    for i, index in enumerate(rank_ID[0:maxres]):

        orgStr =str(imgNames[index], 'utf-8').split("_")
        orgImg = orgStr[0]+"\\"+orgStr[-1]+".jpg"

        imgs.append(mpimg.imread(os.path.join(result, orgImg)))
        matchlist.append(orgImg)
        info.append(orgImg + '_' + str(rank_score[i]))

        print("图片名称是: " + orgImg + " 对应相似度得分是: %f" % rank_score[i])
    # print("top %d 图片如下: " % maxres, matchlist)

    num = int((maxres+1) // 2)

    fig, axs = plt.subplots(nrows=num, ncols=num, figsize=(10, 10))

    # 确保即使只有一个子图,也可以进行索引
    if not isinstance(axs, np.ndarray):
        axs = np.array([[axs]])

    # 显示图像
    flat_index = 0
    for i in range(num):
        for j in range(num):
            if flat_index < len(imgs):
                img = imgs[flat_index]
                axs[i, j].imshow(img, cmap='gray')
                axs[i, j].axis('off')
                axs[i, j].set_title(info[flat_index])
                flat_index += 1
            else:

                axs[i, j].set_visible(False)

    plt.tight_layout()
    plt.show()



if __name__ == "__main__":
    indexfile = './pet_dog.h5'
    query = './data/pic/中华田园犬.jpg'
    result = 'E:\xxx\datas\pet_dog'
    showRes(indexfile, query, result, maxres=3)

未来方向

在以图搜图的应用中,遇到“相同的图片和程序允许两次,得到的结果还不一样”的问题,以及“准确率很低”的情况,通常表明特征提取和相似度计算存在缺陷。以下是针对这些问题的具体分析和优化方案:

问题分析

  1. 特征提取不具代表性:
    使用预训练的 ResNet18 网络可能没有针对犬类这一特定领域的特征进行优化。
  2. 特征表示单一:
    当前的特征提取方法仅从单张图片中提取特征向量,未能有效地概括类别特征。
  3. 相似度衡量方法不足:
    当前使用的相似度衡量方法(如点积或余弦相似度)可能不够精确或不适合犬类图片的特征比较。

优化方向

  1. 提高特征表达能力
    重新训练 ResNet18:
    使用包含11个犬类类别的数据集重新训练 ResNet18 模型,使其更适应特定领域。
    使用数据增强技术(如旋转、缩放、翻转等)来增加训练样本的多样性,从而提高模型的泛化能力。
  2. 使用类向量表示
    类向量表示:
    考虑从类别角度出发,提取类向量(class vector)。这可以通过对同一类别的多个样本进行平均或聚合操作来实现。使用聚类算法(如 K-Means)对同类图片的特征向量进行聚合,得到类别中心点,然后用类别中心点作为类向量进行相似度计算。
  3. 优化相似度衡量方法
    引入 FAISS(Facebook AI Similarity Search)库来优化特征向量的存储和检索过程。FAISS 提供了高效的近似最近邻搜索算法,可以显著提高检索速度和准确性。

http://www.niftyadmin.cn/n/5685107.html

相关文章

GEE数据集:1996 年到 2020 年全球红树林观测数据集(JAXA)(更新)

目录 简介 数据集说明 数据集 代码 代码链接 结果 引用 许可 网址推荐 0代码在线构建地图应用 机器学习 简介 全球红树林观测 这项研究使用了日本宇宙航空研究开发机构&#xff08;JAXA&#xff09;提供的 L 波段合成孔径雷达&#xff08;SAR&#xff09;全球mask…

buff叠满!软考报名越晚,批次越晚?考试越难?

近日&#xff0c;各地软考办都发布了2024年下半年软考批次安排。 报考了软考中级-系统集成项目管理工程师&#xff08;简称“集成”&#xff09;的广东考生炸锅了&#xff0c;我会被分到11月11日&#xff08;周一&#xff09;的第四批次、第五批次考试吗&#xff1f; 软考批次是…

CSS 中的@media print 是干什么用的?

media print { ... } 是CSS中的一个媒体查询&#xff0c;它专门用于定义当内容被打印到纸张上时应该应用的样式规则。在这个查询块内&#xff0c;你可以设置各种样式&#xff0c;以确保打印输出的内容看起来整洁、专业&#xff0c;并且只包含必要的信息。 在你给出的例子中&am…

Spring - @Import注解

文章目录 基本用法源码分析ConfigurationClassPostProcessorConfigurationClass SourceClassgetImportsprocessImports处理 ImportSelectorImportSelector 接口DeferredImportSelector 处理 ImportBeanDefinitionRegistrarImportBeanDefinitionRegistrar 接口 处理Configuratio…

《动手学深度学习》笔记2.5——神经网络从基础→使用GPU (CUDA-单卡-多卡-张量操作)

目录 0. 前言 原书正文 1. 计算设备 (CPU和GPU) 补充&#xff1a;torch版本cuda报错的解决方案 2. 张量与GPU 3. 存储在GPU上 4. 复制&#xff08;多卡操作&#xff09; 5. 旁注 (CPU和GPU之间挪数据) 6. 神经网络与GPU 小结 0. 前言 课程全部代码&#xff08;pytorc…

《ToDesk 云电脑、易腾云、青椒云移动端体验实测:让手机秒变超级电脑》

前言 科技发展到如今2024年&#xff0c;可以说每一年都在发生翻天覆地的变化。云电脑这个市场近年来迅速发展&#xff0c;无需购买和维护额外的硬件就可以体验到电脑端顶配的性能和体验&#xff0c;并且移动端也可以带来非凡体验。我们在外出办公随身没有携带电脑情况下&#x…

【C++——文件操作】

写入 #include<iostream> #include<fstream> //ofstream所需头文件 using namespace std;int main() {//一打开文件:string str R"(C:\Users\admin\Desktop\新建文件夹\test.txt)";//也可以用C风格字符串//打开文件&#xff0c;如果不存在就创建一…

Llama微调以及Ollama部署

1 Llama微调 在基础模型的基础上&#xff0c;通过一些特定的数据集&#xff0c;将具有特定功能加在原有的模型上。 1.1 效果对比 特定数据集 未使用微调的基础模型的回答 使用微调后的回答 1.2 基础模型 基础大模型我选择Mistral-7B-v0.3-Chinese-Chat-uncensored&#x…