抠图党福音:教你一键分割图像

adminmycode2年前python1011

Segment Anything

Segment Anything Model(SAM)通过点或框等输入提示生成高质量的对象分割区域,并且可以用于为图像中的所有对象生成分割区域。它已经在 1100 万张图像和 11 亿个分割区域的数据集上进行了训练,并且在各种分割任务上具有强大的零样本性能。

SAM 的工作原理:可提示分割

在自然语言处理和最近的计算机视觉领域,最令人兴奋的发展之一是基础模型的发展,这些基础模型可以使用提示技术(prompting)对新数据集和任务执行零样本和小样本学习。我们从这类工作中汲取了灵感。

我们训练 SAM 为任何提示返回有效的分割掩码,其中提示可以是前景 / 背景点、粗框或掩码、自由格式文本。或者一般来说,提示图像中要分割的内容的任何信息。有效掩码的要求仅仅意味着即使提示不明确并且可能指代多个对象(例如,衬衫上的一个点可能表示衬衫或穿着它的人),输出也应该是一个合理的掩码对象之一。此任务用于预训练模型并通过提示解决一般的下游分割任务。

我们观察到预训练任务和交互式数据收集对模型设计施加了特定的限制。特别是,该模型需要在 Web 浏览器的 CPU 上实时运行,以允许我们的标注者实时交互地使用 SAM 以高效地进行标注。虽然运行时限制意味着质量和运行时之间的权衡,但我们发现简单的设计在实践中会产生良好的结果。具体地,图像编码器为图像生成一次性嵌入向量,而轻量级编码器将任何提示实时转换为嵌入向量。然后将这两个信息源组合在一个预测分割掩码的轻量级解码器中。在计算图像嵌入后,SAM 可以在 50 毫秒内根据网络浏览器中的任何提示生成一个分割。

SAM 模型总体上分为 3 部分:

绿色的图像编码器,基于可扩展和强大的预训练方法,我们使用 MAE 预训练的 ViT,最小限度地适用于处理高分辨率输入。图像编码器对每张图像运行一次,在提示模型之前进行应用。

紫色的提示编码器,考虑两组 prompt:稀疏(点、框、文本)和密集(掩码)。我们通过位置编码来表示点和框,并将对每个提示类型的学习嵌入自由形式的文本与 CLIP 中的现成文本编码相加。密集的提示(即掩码)使用卷积进行嵌入,并通过图像嵌入进行元素求和。

橙色的提示编码器,掩码解码器有效地将图像嵌入、提示嵌入和输出 token 映射到掩码。该设计的灵感来自于 DETR,采用了对(带有动态掩模预测头的)Transformer decoder 模块的修改。

Segment Anything 适配 ModelArts

使用方法:

输入一个图像,通过 Segment Anything 模型即可获得图像所有目标的分割点位置,再通过位置将图像进行分割保存。

本案例需使用 Pytorch-1.8 GPU-P100 及以上规格运行

点击 Run in ModelArts,将会进入到 ModelArts CodeLab 中,这时需要你登录华为云账号,如果没有账号,则需要注册一个,且要进行实名认证,参考《ModelArts 准备工作_简易版》 即可完成账号注册和实名认证。登录之后,等待片刻,即可进入到 CodeLab 的运行环境

出现 Out Of Memory ,请检查是否为您的参数配置过高导致,修改参数配置,重启 kernel 或更换更高规格资源进行规避❗❗❗

1. 环境准备

为了方便用户下载使用及快速体验,本案例已将代码及 segment-anything 预训练模型转存至华为云 OBS 中。模型下载与加载需要几分钟时间。

import osimport torch
import os.path as osp
import moxing as moxpath = osp.join(os.getcwd(),'segment-anything')if not os.path.exists(path):
 mox.file.copy_parallel('obs://modelarts-labs-bj4-v2/case_zoo/segment-anything', path) if os.path.exists(path): print('Download success') else:
        raise Exception('Download Failed')else: print("Model Package already exists!")

check GPU & 安装依赖

大约耗时 1min

%cd segment-anything
!pip install --upgrade pip
!pip install torch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1!pip install opencv-python matplotlib
!python setup.py installimport numpy as npimport matplotlib.pyplot as pltimport cv2import copyimport torchimport torchvisionprint("PyTorch version:", torch.__version__)print("Torchvision version:", torchvision.__version__)print("CUDA is available:", torch.cuda.is_available())

2. 加载模型

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
sam_checkpoint = "sam_vit_h_4b8939.pth"model_type = "vit_h"device = "cuda"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(
    model=sam,    #points_per_side=32,
    #pred_iou_thresh=0.86,
    #stability_score_thresh=0.92,
    #crop_n_layers=1,
    #crop_n_points_downscale_factor=2,
    #min_mask_region_area=100,  # Requires open-cv to run post-processing)

3. 一键分割所有目标

def show_anns(anns,image):
 segment_image = copy.copy(image)
 segment_image.astype("uint8")
 if len(anns) == 0:
 return
 sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
 for ann in sorted_anns:
        mask_2d = ann['segmentation']
 h,w = mask_2d.shape
        mask_3d_color = np.zeros((h,w,3), dtype=np.uint8)
        mask = (mask_2d!=0).astype(bool)
 rgb = np.random.randint(0, 255, (1, 3), dtype=np.uint8)
        mask_3d_color[mask_2d[:, :] == 1] = rgb
 segment_image[mask] = segment_image[mask] * 0.5 + mask_3d_color[mask] * 0.5
 return segment_imageimage = cv2.imread('images/dog.jpg')image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)masks = mask_generator.generate(image)segment_image = show_anns(masks,image)fig = plt.figure(figsize=(25, 10))ax1 = fig.add_subplot(1, 2, 1)plt.title('Original image', fontsize=16)ax1.axis('off')ax1.imshow(image)ax2 = fig.add_subplot(1, 2, 2)plt.title('Segment image', fontsize=16)ax2.axis('off')ax2.imshow(segment_image)plt.show()

4. 保存所有分割的图片

将所有识别出来的分割位置进行分割,并保存成图片。

def apply_mask(image, mask, alpha_channel=True):#应用并且响应mask
 if alpha_channel:
        alpha = np.zeros_like(image[..., 0])#制作掩体
        alpha[mask == 1] = 255#兴趣地方标记为1,且为白色
        image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha))#融合图像
 else:
        image = np.where(mask[..., None] == 1, image, 0) return imagedef mask_image(image, mask, crop_mode_=True):#保存掩盖部分的图像(感兴趣的图像)
 if crop_mode_:
        y, x = np.where(mask)
 y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max()
 cropped_mask = mask[y_min:y_max+1, x_min:x_max+1]
 cropped_image = image[y_min:y_max+1, x_min:x_max+1]
 masked_image = apply_mask(cropped_image, cropped_mask) else:
 masked_image = apply_mask(image, mask) return masked_imagedef save_masked_image(image, filepath): if image.shape[-1] == 4:
        cv2.imwrite(filepath, image, [cv2.IMWRITE_PNG_COMPRESSION, 9]) else:
        cv2.imwrite(filepath, image)
 print(f"Saved as {filepath}")def save_anns(anns,image,path): if len(anns) == 0: return
 sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)    index = 1
 for ann in sorted_anns:
        mask_2d = ann['segmentation']
 segment_image = copy.copy(image)
 masked_image = mask_image(segment_image, mask_2d)
        filename = str(index) + '.png'
 filepath = os.path.join(path, filename)
 save_masked_image(masked_image, filepath)        index = index + 1save_path = 'result/'if not os.path.exists(save_path):
 os.mkdir(save_path)
image = cv2.imread('images/dog.jpg')
masks = mask_generator.generate(image)
save_anns(masks,image,save_path)

5. Gradio 可视化部署

为了方便大家使用一键分割案例,当前增加了 Gradio 可视化部署案例演示。

运行如下代码,Gradio 应用启动后可在下方页面进行一键分割图像,您也可以分享 public url 在手机端,PC 端进行访问生成图像。

示例效果如下:

!pip install gradio==3.24.1def segment_image(image):
    masks = mask_generator.generate(image) return show_anns(masks,image)def show_image(image):
    masks = mask_generator.generate(image) if len(masks) == 0: return
 sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)    index = 1
 image_list = [] for ann in sorted_anns:
        mask_2d = ann['segmentation']
 segment_image = copy.copy(image)
 masked_image = mask_image(segment_image, mask_2d)
 image_list.append(masked_image) return image_listimport gradio as grwith gr.Blocks() as demo: with gr.Row(): with gr.Column():
 img_in = gr.Image(source='upload') with gr.Row():
 segment_button = gr.Button("segment",variant="primary")
 save_button = gr.Button("segment_images",variant="primary") with gr.Row(): with gr.Column():
 img_out = gr.Image() with gr.Row():
 result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=6, height='auto')
 segment_button.click(segment_image,
                 inputs= [img_in], 
                 outputs=[img_out])
 save_button.click(show_image,
                 inputs= [img_in], 
                 outputs=[result_gallery])
demo.launch(share=True)


标签: Python

相关文章

[Python从零到壹] 二.语法基础之条件语句、循环语句和函数

一.条件语句在讲诉条件语句之前,需要先补充语句块的知识。语句块并非一种语句,它是在条件为真时执行一次或执行多次的一组语句,在代码前放置空格缩进即可创建语句块。它类似于C、C++、Java等语言的大括号...

Python 从零到壹丨详解图像锐化 Roberts、Prewitt 算子实现边缘检测

Python 从零到壹丨详解图像锐化 Roberts、Prewitt 算子实现边缘检测

一。图像锐化由于收集图像数据的器件或传输图像的通道存在一些质量缺陷,或者受其他外界因素的影响,使得图像存在模糊和有噪声的情况,从而影响到图像识别工作的开展。一般来说,图像的能量主要集中在其低频部分,噪...

[Python从零到壹] 三.语法基础之文件操作、CSV文件读写及面向对象 | 【生长吧!Python】

一.文件操作文件是指存储在外部介质上数据的集合,文本文件编码方式包括ASCII格式、Unicode码、UTF-8码、GBK编码等。文件的操作流程为“打开文件-读写文件-关闭文件”三部曲。1.打开文件打...

Python从零到壹丨带你了解图像直方图理论知识和绘制实现

Python从零到壹丨带你了解图像直方图理论知识和绘制实现

一.图像直方图理论知识灰度直方图是灰度级的函数,描述的是图像中每种灰度级像素的个数,反映图像中每种灰度出现的频率。假设存在一幅6×6像素的图像,接着统计其1至6灰度级的出现频率,并绘制如图1所示的柱状...

Python 从零到壹丨带你了解图像直方图理论知识和绘制实现

Python 从零到壹丨带你了解图像直方图理论知识和绘制实现

一。图像直方图理论知识灰度直方图是灰度级的函数,描述的是图像中每种灰度级像素的个数,反映图像中每种灰度出现的频率。假设存在一幅 6×6 像素的图像,接着统计其 1 至 6 灰度级的出现频率,并绘制如图...

[Python从零到壹] 五.网络爬虫之BeautifulSoup基础语法万字详解 | 【生长吧!Python】

一.安装BeautifulSoupBeautifulSoup是一个可以从HTML或XML文件中提取数据的Python扩展库。BeautifulSoup通过合适的转换器实现文档导航、查找、修改文档等。它...

发表评论    

◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。