114python之家

您现在的位置是:首页 > python > 正文

python

一、使用paddleseg套件对遥感影像预测(基础)

admin2021-02-11python417
【114python之家】一、使用paddleseg套件对遥感影像预测(基础)目前paddleseg套件中的predict.py代码文件还不支持直接对遥感影像(大图)做预测,或者说把遥感大图直接丢进p

一、使用paddleseg套件对遥感影像预测(基础)


目前paddleseg套件中的predict.py代码文件还不支持直接对遥感影像(大图)做预测,或者说把遥感大图直接丢进predict.py,它的预测效果非常差。

基于以上问题,本文结合paddleseg中predict.py源码和这篇博文代码(遥感语义分割切图预测之后再拼接)重新写了predict.py代码,希望可以帮助到使用飞桨框架做遥感影像语义分割的朋友,所以这里需要你会使用paddleseg套件或者对paddleseg源码有所了解,这里有位博主写了一系列有关paddleseg源码的文章,值得参考学习(人工智能研习社)。

重新写的predict.py代码主要分为四个部分
读取和裁剪遥感大图网络模型推理预测小图块拼接小图块预测结果拼接结果写入文件

1、读取待预测遥感大图,将遥感大图裁剪成小图块,这里裁剪的小图块相邻之间不设置重叠度,小图块大小为256x256。
本部分代码如下:

#读取需要预测的遥感大图img_lists[local_rank][local_rank]=/home/aistudio/data/data70483/img.png
ori_image=cv2.imread(img_lists[local_rank][local_rank])
h_step=ori_image.shape[0]//256#高度步数
w_step=ori_image.shape[1]//256#宽度步数

h_rest=-(ori_image.shape[0]-256*h_step)#剩余行数
w_rest=-(ori_image.shape[1]-256*w_step)#剩余列数

seg_list=[]#小图块的列表
predict_list=[]#预测小图块结果的列表
#循环切图
forhinrange(h_step):
forwinrange(w_step):
#划窗采样
image_sample=ori_image[(h*256):(h*256+256),
(w*256):(w*256+256),:]
seg_list.append(image_sample)
seg_list.append(ori_image[(h*256):(h*256+256),-256:,:])
forwinrange(w_step-1):
seg_list.append(ori_image[-256:,(w*256):(w*256+256),:])
seg_list.append(ori_image[-256:,-256:,:])

2、利用网络模型推理预测小图块,这里的代码改动不多,但是需要将img_lists[local_rank]参数改成存储小图块的列表seg_list,其他参数的设置根据需要而定,在这里本文只对小图块做最普通的推理预测,既不做多尺度预测、也不做滑窗预测(多尺度和滑窗预测是原本predict.py的功能,当然在这里我们也可以用)。
本部分代码如下:

progbar_pred=progbar.Progbar(target=len(seg_list),verbose=1)
withpaddle.no_grad():
fori,iminenumerate(seg_list):
ori_shape=im.shape[:2]#原始图片形状(h,w)
im,_=transforms(im)#im.shape(3,256,256)_为None
im=im[np.newaxis,...]#im.shape(1,3,256,256)
im=paddle.to_tensor(im)

ifFalse:
pred=infer.aug_inference(
model,
im,
ori_shape=ori_shape,
transforms=transforms.transforms,
scales=scales,
flip_horizontal=flip_horizontal,
flip_vertical=flip_vertical,
is_slide=is_slide,
stride=None,
crop_size=None)
else:
pred=infer.inference(
model,
im,
ori_shape=ori_shape,
transforms=transforms.transforms,
is_slide=False,
stride=None,
crop_size=None)
pred=paddle.squeeze(pred)#该OP会删除输入Tensor的Shape中尺寸为1的维度。查看pred的形状应该剩下[h,w]
pred=pred.numpy().astype('uint8')
predict_list.append(pred)
progbar_pred.update(i+1)

3、将小图块的预测结果进行拼接,这里的拼接思想很简单,就是按照裁剪的顺序进行拼接。
本部分代码如下:

count_temp=0
tmp=np.ones([ori_image.shape[0],ori_image.shape[1]])
forhinrange(h_step):
forwinrange(w_step):
tmp[
h*256:(h+1)*256,
w*256:(w+1)*256
]=predict_list[count_temp]
count_temp+=1
tmp[h*256:(h+1)*256,w_rest:]=predict_list[count_temp][:,w_rest:]
count_temp+=1
forwinrange(w_step-1):
tmp[h_rest:,(w*256):(w*256+256)]=predict_list[count_temp][h_rest:,:]
count_temp+=1
tmp[-257:-1,-257:-1]=predict_list[count_temp][:,:]

4、将拼接结果tmp写入图像文件中,这里使用了原先predict.py的写入函数,只是将函数中pred参数改成了tmp,需要注意的是一定要将tmp变量提前转换为uint8类型,不然程序会报错。
本部分代码如下:

tmp=tmp.astype('uint8')
#saveaddedimage
added_image=utils.visualize.visualize(args.image_path,tmp,weight=0.6)
added_image_path=os.path.join(added_saved_dir,im_file)
mkdir(added_image_path)
cv2.imwrite(added_image_path,added_image)
#savepseudocolorprediction
pred_mask=utils.visualize.get_pseudo_color_map(tmp)
pred_saved_path=os.path.join(pred_saved_dir,
im_file.rsplit(".")[0]+".png")
mkdir(pred_saved_path)
pred_mask.save(pred_saved_path)

到这里代码的主体部分基本上搞定了,值得注意的是paddleseg套件中predict.py会从paddleseg.core调用predict.py,而本文为了方便移植代码,就将两个predict.py写成了一个predict.py。

当时写的第一个版本predict.py将裁剪的小图块尺寸设定为了256,同时将inference有些参数都设定死了,所以不推荐直接copy使用,仅作为参考学习。第一个版本predict.py完整代码如下:

importsys
importargparse
importos
importpaddle
frompaddleseg.cvlibsimportmanager,Config
frompaddleseg.utilsimportget_sys_env,logger
importmath
importcv2
importnumpyasnp
frompaddlesegimportutils
frompaddleseg.coreimportinfer
frompaddleseg.utilsimportprogbar

defmkdir(path):
sub_dir=os.path.dirname(path)#去掉文件名,返回目录
ifnotos.path.exists(sub_dir):
os.makedirs(sub_dir)

defpartition_list(arr,m):
"""splitthelist'arr'intompieces"""
n=int(math.ceil(len(arr)/float(m)))
return[arr[i:i+n]foriinrange(0,len(arr),n)]

defparse_args():
parser=argparse.ArgumentParser(description='Modelprediction')

#paramsofprediction
parser.add_argument(
"--config",dest="cfg",help="Theconfigfile.",default=None,type=str)
parser.add_argument(
'--model_path',
dest='model_path',
help='Thepathofmodelforprediction',
type=str,
default=None)
parser.add_argument(
'--image_path',
dest='image_path',
help=
'Thepathofimage,itcanbeafileoradirectoryincludingimages',
type=str,
default=None)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='Thedirectoryforsavingthepredictedresults',
type=str,
default='./output/result')

#augmentforprediction
parser.add_argument(
'--aug_pred',
dest='aug_pred',
help='Whethertousemulit-scalesandflipaugmentforprediction',
action='store_true')
parser.add_argument(
'--scales',
dest='scales',
nargs='+',
help='Scalesforaugment',
type=float,
default=1.0)
parser.add_argument(
'--flip_horizontal',
dest='flip_horizontal',
help='Whethertousefliphorizontallyaugment',
action='store_true')
parser.add_argument(
'--flip_vertical',
dest='flip_vertical',
help='Whethertouseflipverticallyaugment',
action='store_true')

#slidingwindowprediction
parser.add_argument(
'--is_slide',
dest='is_slide',
help='Whethertopredictionbyslidingwindow',
action='store_true')
parser.add_argument(
'--crop_size',
dest='crop_size',
nargs=2,
help=
'Thecropsizeofslidingwindow,thefirstiswidthandthesecondisheight.',
type=int,
default=None)
parser.add_argument(
'--stride',
dest='stride',
nargs=2,
help=
'Thestrideofslidingwindow,thefirstiswidthandthesecondisheight.',
type=int,
default=None)

returnparser.parse_args()

defget_image_list(image_path):
"""Getimagelist"""
valid_suffix=[
'.JPEG','.jpeg','.JPG','.jpg','.BMP','.bmp','.PNG','.png','.tif'
]
image_list=[]
image_dir=None
ifos.path.isfile(image_path):
ifos.path.splitext(image_path)[-1]invalid_suffix:
image_list.append(image_path)
elifos.path.isdir(image_path):
image_dir=image_path
forroot,dirs,filesinos.walk(image_path):#root=image_path
forfinfiles:
ifos.path.splitext(f)[-1]invalid_suffix:
image_list.append(os.path.join(root,f))
else:
raiseFileNotFoundError(
'`--image_path`isnotfound.itshouldbeanimagefileoradirectoryincludingimages'
)

iflen(image_list)==0:
raiseRuntimeError('Therearenotimagefilein`--image_path`')

returnimage_list,image_dir#返回测试文件列表


defmain(args):
env_info=get_sys_env()
place='gpu'ifenv_info['Paddlecompiledwithcuda']andenv_info[
'GPUsused']else'cpu'

paddle.set_device(place)
ifnotargs.cfg:
raiseRuntimeError('Noconfigurationfilespecified.')

cfg=Config(args.cfg)
val_dataset=cfg.val_dataset
ifnotval_dataset:
raiseRuntimeError(
'Theverificationdatasetisnotspecifiedintheconfigurationfile.'
)

msg='
---------------ConfigInformation---------------
'
msg+=str(cfg)
msg+='------------------------------------------------'
logger.info(msg)

model=cfg.model
transforms=val_dataset.transforms
#image_list,image_dir=get_image_list('data/UAV_seg/images')
image_list,image_dir=get_image_list(args.image_path)#需要传入args.image_path参数这个参数可以是测试图片的路径,也可以是单张图片的路径

model_path=args.model_path,#传入训练模型的路径
save_dir=args.save_dir,
aug_pred=False,
scales=1.0,
flip_horizontal=True,
flip_vertical=False,
is_slide=False,
stride=None,
crop_size=None

para_state_dict=paddle.load(model_path[0])
model.set_dict(para_state_dict)
model.eval()
nranks=paddle.distributed.get_world_size()
local_rank=paddle.distributed.get_rank()
ifnranks>1:
img_lists=partition_list(image_list,nranks)
else:
img_lists=[image_list]#列表的列表img_lists[0]->列表

added_saved_dir=os.path.join(save_dir[0],'added_prediction')#伪彩色和原图叠加
pred_saved_dir=os.path.join(save_dir[0],'pseudo_color_prediction')#伪彩色预测结果

logger.info("Starttopredict...")


##############################1、裁剪遥感大图########################
#读取需要预测的遥感大图img_lists[local_rank][local_rank]=/home/aistudio/data/data70483/img.png
ori_image=cv2.imread(img_lists[local_rank][local_rank])
h_step=ori_image.shape[0]//256#高度步数
w_step=ori_image.shape[1]//256#宽度步数

h_rest=-(ori_image.shape[0]-256*h_step)#剩余行数
w_rest=-(ori_image.shape[1]-256*w_step)#剩余列数

seg_list=[]#由遥感大图裁剪成小图块的列表
predict_list=[]#预测小图块结果的列表
#循环切图
forhinrange(h_step):
forwinrange(w_step):
#划窗采样
image_sample=ori_image[(h*256):(h*256+256),
(w*256):(w*256+256),:]
seg_list.append(image_sample)
seg_list.append(ori_image[(h*256):(h*256+256),-256:,:])
forwinrange(w_step-1):
seg_list.append(ori_image[-256:,(w*256):(w*256+256),:])
seg_list.append(ori_image[-256:,-256:,:])
##############################裁剪结束########################

##############################2、利用网络模型推理小图块########################
progbar_pred=progbar.Progbar(target=len(seg_list),verbose=1)
withpaddle.no_grad():
fori,iminenumerate(seg_list):
ori_shape=im.shape[:2]#原始图片形状(h,w)
im,_=transforms(im)#im.shape(3,256,256)_为None
im=im[np.newaxis,...]#im.shape(1,3,256,256)
im=paddle.to_tensor(im)

ifFalse:
pred=infer.aug_inference(
model,
im,
ori_shape=ori_shape,
transforms=transforms.transforms,
scales=scales,
flip_horizontal=flip_horizontal,
flip_vertical=flip_vertical,
is_slide=is_slide,
stride=None,
crop_size=None)
else:
pred=infer.inference(
model,
im,
ori_shape=ori_shape,
transforms=transforms.transforms,
is_slide=False,
stride=None,
crop_size=None)
pred=paddle.squeeze(pred)#该OP会删除输入Tensor的Shape中尺寸为1的维度。查看pred的形状应该剩下[h,w]
pred=pred.numpy().astype('uint8')
predict_list.append(pred)
progbar_pred.update(i+1)
##############################推理结束########################

#############3、将预测后的图像块再拼接起来########################
count_temp=0
tmp=np.ones([ori_image.shape[0],ori_image.shape[1]])
forhinrange(h_step):
forwinrange(w_step):
tmp[
h*256:(h+1)*256,
w*256:(w+1)*256
]=predict_list[count_temp]
count_temp+=1
tmp[h*256:(h+1)*256,w_rest:]=predict_list[count_temp][:,w_rest:]
count_temp+=1
forwinrange(w_step-1):
tmp[h_rest:,(w*256):(w*256+256)]=predict_list[count_temp][h_rest:,:]
count_temp+=1
tmp[-257:-1,-257:-1]=predict_list[count_temp][:,:]
##################拼接结束########################

#获取需要保存的图片名称,去掉前面的路径
#getthesavedname
ifimage_dirisnotNone:
pass
#im_file=im_path.replace(image_dir,'')#例:将PaddleSeg/data/optic_disc_seg/JPEGImages/P0011.jpg替换为/P0011.jpg
else:
im_file=os.path.basename(img_lists[local_rank][local_rank])#带后缀名
ifim_file[0]=='/':#去掉/
im_file=im_file[1:]

#############
tmp=tmp.astype('uint8')
#saveaddedimage
added_image=utils.visualize.visualize(args.image_path,tmp,weight=0.6)
added_image_path=os.path.join(added_saved_dir,im_file)
mkdir(added_image_path)
cv2.imwrite(added_image_path,added_image)

#savepseudocolorprediction
pred_mask=utils.visualize.get_pseudo_color_map(tmp)
pred_saved_path=os.path.join(pred_saved_dir,
im_file.rsplit(".")[0]+".png")
mkdir(pred_saved_path)
pred_mask.save(pred_saved_path)

#pred_im=utils.visualize(im_path,pred,weight=0.0)
#pred_saved_path=os.path.join(pred_saved_dir,im_file)
#mkdir(pred_saved_path)
#cv2.imwrite(pred_saved_path,pred_im)

#progbar_pred.update(i+1)

if__name__=='__main__':
args=parse_args()
main(args)

第二个版本的predict.py将“裁剪遥感大图”和“拼接小图块的预测结果”封装成了函数,分别为CropBigImage(ImagePath,CropScale)PinJie(predict_list,CropScale,ori_image,h_step,w_step,h_rest,w_rest)。CropBigImage函数可以将遥感大图裁剪成任意尺寸的小图块,第二个版本的predict.py完整代码如下:

importsys
importargparse
importos
importpaddle
frompaddleseg.cvlibsimportmanager,Config
frompaddleseg.utilsimportget_sys_env,logger
importmath
importcv2
importnumpyasnp
frompaddlesegimportutils
frompaddleseg.coreimportinfer
frompaddleseg.utilsimportprogbar

defmkdir(path):
sub_dir=os.path.dirname(path)#去掉文件名,返回目录
ifnotos.path.exists(sub_dir):
os.makedirs(sub_dir)

defpartition_list(arr,m):
"""splitthelist'arr'intompieces"""
n=int(math.ceil(len(arr)/float(m)))
return[arr[i:i+n]foriinrange(0,len(arr),n)]

defparse_args():
parser=argparse.ArgumentParser(description='Modelprediction')

#paramsofprediction
parser.add_argument(
"--config",dest="cfg",help="Theconfigfile.",default=None,type=str)
parser.add_argument(
'--model_path',
dest='model_path',
help='Thepathofmodelforprediction',
type=str,
default=None)
parser.add_argument(
'--image_path',
dest='image_path',
help=
'Thepathofimage,itcanbeafileoradirectoryincludingimages',
type=str,
default=None)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='Thedirectoryforsavingthepredictedresults',
type=str,
default='./output/result')

#augmentforprediction
parser.add_argument(
'--aug_pred',
dest='aug_pred',
help='Whethertousemulit-scalesandflipaugmentforprediction',
action='store_true')
parser.add_argument(
'--scales',
dest='scales',
nargs='+',
help='Scalesforaugment',
type=float,
default=1.0)
parser.add_argument(
'--flip_horizontal',
dest='flip_horizontal',
help='Whethertousefliphorizontallyaugment',
action='store_true')
parser.add_argument(
'--flip_vertical',
dest='flip_vertical',
help='Whethertouseflipverticallyaugment',
action='store_true')

#slidingwindowprediction
parser.add_argument(
'--is_slide',
dest='is_slide',
help='Whethertopredictionbyslidingwindow',
action='store_true')
parser.add_argument(
'--crop_size',
dest='crop_size',
nargs=2,
help=
'Thecropsizeofslidingwindow,thefirstiswidthandthesecondisheight.',
type=int,
default=None)
parser.add_argument(
'--stride',
dest='stride',
nargs=2,
help=
'Thestrideofslidingwindow,thefirstiswidthandthesecondisheight.',
type=int,
default=None)

returnparser.parse_args()


defget_image_list(image_path):
"""Getimagelist"""
valid_suffix=[
'.JPEG','.jpeg','.JPG','.jpg','.BMP','.bmp','.PNG','.png','.tif'
]
image_list=[]
image_dir=None
ifos.path.isfile(image_path):
ifos.path.splitext(image_path)[-1]invalid_suffix:
image_list.append(image_path)
elifos.path.isdir(image_path):
image_dir=image_path
forroot,dirs,filesinos.walk(image_path):#root=image_path
forfinfiles:
ifos.path.splitext(f)[-1]invalid_suffix:
image_list.append(os.path.join(root,f))
else:
raiseFileNotFoundError(
'`--image_path`isnotfound.itshouldbeanimagefileoradirectoryincludingimages'
)

iflen(image_list)==0:
raiseRuntimeError('Therearenotimagefilein`--image_path`')

returnimage_list,image_dir#返回测试文件列表

defCropBigImage(ImagePath,CropScale):

ImagePath=ImagePath
CropScale=CropScale
seg_list=[]#存储分割的图块
ori_image=cv2.imread(ImagePath)##
h_step=ori_image.shape[0]//CropScale
w_step=ori_image.shape[1]//CropScale

h_rest=-(ori_image.shape[0]-CropScale*h_step)
w_rest=-(ori_image.shape[1]-CropScale*w_step)

#循环切图
forhinrange(h_step):
forwinrange(w_step):
#划窗采样
image_sample=ori_image[(h*CropScale):(h*CropScale+CropScale),
(w*CropScale):(w*CropScale+CropScale),:]
seg_list.append(image_sample)
seg_list.append(ori_image[(h*CropScale):(h*CropScale+CropScale),-CropScale:,:])
forwinrange(w_step-1):
seg_list.append(ori_image[-CropScale:,(w*CropScale):(w*CropScale+CropScale),:])
seg_list.append(ori_image[-CropScale:,-CropScale:,:])

returnseg_list,ori_image,h_step,w_step,h_rest,w_rest

defPinJie(predict_list,CropScale,ori_image,h_step,w_step,h_rest,w_rest):

#将预测后的图像块再拼接起来
count_temp=0
tmp=np.ones([ori_image.shape[0],ori_image.shape[1]])
forhinrange(h_step):
forwinrange(w_step):
tmp[
h*CropScale:(h+1)*CropScale,
w*CropScale:(w+1)*CropScale
]=predict_list[count_temp]
count_temp+=1
tmp[h*CropScale:(h+1)*CropScale,w_rest:]=predict_list[count_temp][:,w_rest:]
count_temp+=1
forwinrange(w_step-1):
tmp[h_rest:,(w*CropScale):(w*CropScale+CropScale)]=predict_list[count_temp][h_rest:,:]
count_temp+=1
tmp[-(CropScale+1):-1,-(CropScale+1):-1]=predict_list[count_temp][:,:]
returntmp.astype('uint8')

defmain(args):
env_info=get_sys_env()
place='gpu'ifenv_info['Paddlecompiledwithcuda']andenv_info[
'GPUsused']else'cpu'

paddle.set_device(place)
ifnotargs.cfg:
raiseRuntimeError('Noconfigurationfilespecified.')

cfg=Config(args.cfg)
val_dataset=cfg.val_dataset#用val_dataset?
ifnotval_dataset:
raiseRuntimeError(
'Theverificationdatasetisnotspecifiedintheconfigurationfile.'
)

msg='
---------------ConfigInformation---------------
'
msg+=str(cfg)
msg+='------------------------------------------------'
logger.info(msg)

model=cfg.model
transforms=val_dataset.transforms
#image_list,image_dir=get_image_list('data/UAV_seg/images')
image_list,image_dir=get_image_list(args.image_path)#需要传入args.image_path参数这个参数可以是测试图片的路径,也可以是单张图片的路径

model_path=args.model_path#传入训练模型的路径
save_dir=args.save_dir
aug_pred=args.aug_pred
scales=args.scales
flip_horizontal=args.flip_horizontal
flip_vertical=args.flip_vertical
is_slide=args.is_slide
crop_size=args.crop_size
stride=args.stride

para_state_dict=paddle.load(model_path)
model.set_dict(para_state_dict)
model.eval()
nranks=paddle.distributed.get_world_size()
local_rank=paddle.distributed.get_rank()
ifnranks>1:
img_lists=partition_list(image_list,nranks)
else:
img_lists=[image_list]#是列表还是列表的列表,等待测试img_lists[0]->列表的列表

added_saved_dir=os.path.join(save_dir,'added_prediction')#伪彩色和原图叠加
pred_saved_dir=os.path.join(save_dir,'pseudo_color_prediction')#伪彩色预测结果

#主要将遥感大图裁剪成固定尺寸的图块,生成图块列表
ImagePath=img_lists[local_rank][local_rank]
CropScale=256
seg_list,ori_image,h_step,w_step,h_rest,w_rest=CropBigImage(ImagePath,CropScale)

predict_list=[]
progbar_pred=progbar.Progbar(target=len(seg_list),verbose=1)
logger.info("Starttopredict...")
withpaddle.no_grad():
fori,iminenumerate(seg_list):
ori_shape=im.shape[:2]#原始图片形状(h,w)
im,_=transforms(im)#im.shape(3,512,512)_为None
im=im[np.newaxis,...]#im.shape(1,3,512,512)
im=paddle.to_tensor(im)

ifaug_pred:
pred=infer.aug_inference(
model,
im,
ori_shape=ori_shape,
transforms=transforms.transforms,
scales=scales,
flip_horizontal=flip_horizontal,
flip_vertical=flip_vertical,
is_slide=is_slide,
stride=stride,
crop_size=crop_size)
else:
pred=infer.inference(
model,
im,
ori_shape=ori_shape,
transforms=transforms.transforms,
is_slide=is_slide,
stride=stride,
crop_size=crop_size)
pred=paddle.squeeze(pred)#该OP会删除输入Tensor的Shape中尺寸为1的维度。查看pred的形状应该剩下[h,w]
pred=pred.numpy().astype('uint8')
predict_list.append(pred)
progbar_pred.update(i+1)

#主要将图块的预测结果拼接成大图
tmp=PinJie(predict_list,CropScale,ori_image,h_step,w_step,h_rest,w_rest)
#############
#获取需要保存的图片名称,去掉前面的路径
#getthesavedname
ifimage_dirisnotNone:
pass
#im_file=im_path.replace(image_dir,'')#例:将PaddleSeg/data/optic_disc_seg/JPEGImages/P0011.jpg替换为/P0011.jpg
else:
im_file=os.path.basename(img_lists[local_rank][local_rank])#带后缀名
ifim_file[0]=='/':#去掉/
im_file=im_file[1:]

#saveaddedimage
added_image=utils.visualize.visualize(args.image_path,tmp,weight=0.6)
added_image_path=os.path.join(added_saved_dir,im_file)
mkdir(added_image_path)
cv2.imwrite(added_image_path,added_image)

#savepseudocolorprediction
pred_mask=utils.visualize.get_pseudo_color_map(tmp)
pred_saved_path=os.path.join(pred_saved_dir,
im_file.rsplit(".")[0]+".png")
mkdir(pred_saved_path)
pred_mask.save(pred_saved_path)
logger.info("-"*30+"END"+"-"*30)

if__name__=='__main__':
args=parse_args()
main(args)

改写了predict.py源码文件,就要测试下它的效果,本文用了一张无人机遥感影像,目的是作物分类。如下图:
一、使用paddleseg套件对遥感影像预测(基础) python 第1张
为了减轻边缘效应和拼接痕迹,这里使用重叠度为50%的裁剪方式将原图裁剪成7000多张256x256的数据集,利用Unet网络对数据集进行训练,利用本文改写的predict.py对原图进行预测。运行predict.py代码参考如下:!pythonpredict.py--configunet-uav.yml--model_pathoutput/best_model/model.pdparams--image_path/home/aistudio/data/data70483/img.png,其中–config–model_path–image_path都是需要传入的参数,有这些参数但不仅限这些参数。预测结果图如下:
一、使用paddleseg套件对遥感影像预测(基础) python 第2张
一、使用paddleseg套件对遥感影像预测(基础) python 第3张
从语义分割的结果来看,感觉还不错,不过这里我的训练集和测试集是同一个数据集,所以并不能说明网络模型的泛化能力,只能说明网络模型的拟合能力还可以,但是本文目的已经达到了,就是对遥感影像(大图)预测。

各位小伙伴有任何问题可以在评论中留言,下一篇博文的内容依然是使用paddleseg套件对遥感影像预测,不过下篇博文的方法和以上代码有所差别,主要是做有重叠度裁剪待预测遥感大图和忽略相邻图块重叠部分做拼接,目的是为了减轻边缘效应和拼接痕迹,这对语义分割来说十分重要。

扫描关注公众号,第一时间获取网站更新动态

转载请说明来源于"114python之家"

本文地址:http://www.114python.com/post/9694.html

发表评论

评论列表

  • 这篇文章还没有收到评论,赶紧来抢沙发吧~