文档

图像分类训练和预测

更新时间:

简介

图像分类是指识别图片中主体或者状态单一的场景。

前提

在对图像分类进行训练之前,要准备好如下数据:

  • 开通OSS授权

  • 用于训练的图片集。

  • 图片集对应的标签。

操作步骤

下面将以JAVA SDK为例,详细描述如何训练自己的图像分类模型。操作步骤如下:

1.创建项目。核心示例代码:

        CreateProjectRequest request = new CreateProjectRequest();
        request.setName("图像分类测试");
        request.setDescription("图像分类描述");
        request.setProType("Classification");
        CreateProjectResponse response = client.getAcsResponse(request)
        // 保存项目ID。
        String projectId = response.getProject().getProjectId();

2.创建图片标签,每个项目最少需要两个以上的标签。例子中,将创建苹果和香蕉的标签。核心代码:

        CreateTagRequest request = new CreateTagRequest();
        // 创建项目时,返回的项目ID。
        request.setProjectId(projectId);
        request.setName("苹果");
        request.setDescription("苹果的描述");
        CreateTagResponse response = client.getAcsResponse(request);
        // 保存苹果的标签ID。
        String appleTagId = response.getTag().getTagId();

        request.setName("香蕉");
        request.setDescription("香蕉的描述");
        response = client.getAcsResponse(request);
        // 保存香蕉的标签ID。
        String bananaTagId = response.getTag().getTagId();

3.创建训练数据,并做标注。

  • 将苹果和香蕉的图片分别上传至OSS。

  • 通过创建训练数据接口将OSS文件添加到训练集,同时标注。例子里,核心代码:

          CreateTrainDatasFromUrlsRequest request = new CreateTrainDatasFromUrlsRequest();
          request.setProjectId(projectId);
          // 添加苹果训练数据。
          // OSS地址URL列表,用","分隔。请将OSS地址替换成自己的地址。
          request.setUrls("http://test-bucket.oss-cn-beijing.aliyuncs.com/apple/1.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/apple/2.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/apple/3.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/apple/4.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/apple/5.jpg,");
          request.setTagId(appleTagId);
          CreateTrainDatasFromUrlsResponse response = client.getAcsResponse(request);
    
          // 添加香蕉训练数据。
          // OSS地址URL列表,用","分隔。请将OSS地址替换成自己的地址。
          request.setUrls("http://test-bucket.oss-cn-beijing.aliyuncs.com/banana/1.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/banana/2.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/banana/3.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/banana/4.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/banana/5.jpg,");
          request.setTagId(bananaTagId);
          CreateTrainDatasFromUrlsResponse response = client.getAcsResponse(request);

4.开始训练数据准备完毕之后,调用开始训练接口,并等待训练完成。模型训练的时间较长,请耐心等待。核心代码:

        TrainProjectRequest request = new TrainProjectRequest();
        request.setProjectId(projectId);
        TrainProjectResponse response = client.getAcsResponse(request);
        // 保存迭代ID。
        String iterationId = response.getIterationId();

        // 等待迭代完成
        while(true) {
            DescribeTrainResultRequest request = new DescribeTrainResultRequest();
            request.setProjectId(projectId);
            request.setIterationId(iterationId);
            TrainProjectResponse response = client.getAcsResponse(request);
            if ("TrainSuccess".equals(response.getTrainResult().getStatus())) {
                break;
            }
            TimeUnit.SECONDS.sleep(5);
        }

5.预测图片训练结束之后,可以拿到训练时的迭代ID,进行预测。预测作业为异步接口,提交完成之后,通过查询接口进行查询预测结果。核心代码:

        // 提交图片预测。
        PredictImageRequest request = new PredictImageRequest();
        request.setProjectId(projectId);
        request.setIterationId(iterationId);
        // 替换成OSS图片URL
        request.setDataUrls("test-bucket.oss-cn-beijing.aliyuncs.com/predict/1.jpg");
        PredictImageResponse response = client.getAcsResponse(request);

        // 等待一会
        TimeUnit.SECONDS.sleep(10);

        // 查询结果
        List<PredictImageResponse.PredictData> datas = response.getPredictDatas();
        List<String> dataIds = new ArrayList<>(datas.size());
        for (PredictImageResponse.PredictData data : datas) {
            dataIds.add(data.getDataId());
        }
        DescribePredictDatasRequest request = new DescribePredictDatasRequest();
        request.setProjectId(projectId);
        request.setIterationId(iterationId);
        request.setDataIds(dataIds);
        PredictImageResponse response = client.getAcsResponse(request);
        // 输出预测结果
        System.out.println(JSON.toJSONString(response));
  • 本页导读 (0)
文档反馈