Based on a DataScience cluster, you can use the Spark compute engine to run large-scale image inference, making full use of the CPU or GPU computing power of each worker node to quickly process terabytes or even petabytes of image data.
Prerequisites
Background information
This topic describes the distributed inference solution for the following scenarios:
Important: DataScience clusters can already run CPU-based computing. If you need to run GPU-based computing, contact the product O&M staff to upgrade CUDA to version 10.1.
Images stored in HDFS
The advantage of this scenario is that you can use the HDFS storage that ships with the DataScience cluster, without purchasing separate storage. The drawback is that HDFS throughput is low for large numbers of small files, such as images of only a few KB.
You can view DistributedPredictHDFS.java in your IDE. The code is as follows.
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.input.PortableDataStream;

import ai.djl.Application;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import scala.Tuple2;

public class DistributedPredictHDFS {
    public static void main(String[] args) throws Exception {
        System.out.println("Start DistributedPredictHDFS Job.");
        String imagesPath = args[0];
        String modelPath = args[1];
        String resultPath = args[2];
        SparkConf conf = new SparkConf().setAppName("DistributedPredictHDFSOnSpark");
        JavaSparkContext sc = new JavaSparkContext(conf);
        // Read the images as binary files, split into at least 128 partitions.
        JavaPairRDD<String, PortableDataStream> imageStream = sc.binaryFiles(imagesPath, 128);
        System.out.println("Partitions: " + imageStream.getNumPartitions());
        JavaRDD<String> result = imageStream.mapPartitions(new FlatMapFunction<Iterator<Tuple2<String, PortableDataStream>>, String>() {
            private static final long serialVersionUID = 1L;
            @Override
            public Iterator<String> call(Iterator<Tuple2<String, PortableDataStream>> iterator) throws Exception {
                ImageClassificationTranslator translator =
                        ImageClassificationTranslator.builder()
                                .addTransform(new Resize(224, 224))
                                .addTransform(new ToTensor())
                                .build();
                Criteria<Image, Classifications> criteria =
                        Criteria.builder()
                                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                                .setTypes(Image.class, Classifications.class) // defines input and output data type
                                .optTranslator(translator)
                                .optModelUrls(modelPath)
                                .build();
                System.out.println("EngineName: " + Engine.getInstance().getEngineName());
                System.out.println("ModelName: " + criteria.getModelName());
                System.out.println("CriteriaInfo: " + criteria.toString());
                List<String> list = new ArrayList<>();
                // Load the model once per partition and reuse the predictor for all batches.
                ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
                Predictor<Image, Classifications> predictor = model.newPredictor();
                int idx = 0;
                List<Image> imagelist = new ArrayList<>();
                while (iterator.hasNext()) {
                    Tuple2<String, PortableDataStream> item = iterator.next();
                    String name = item._1();
                    PortableDataStream content = item._2();
                    Image img = ImageFactory.getInstance().fromInputStream(content.open());
                    imagelist.add(img);
                    idx++;
                    // Run inference once a full batch of 32 images has been collected.
                    if (imagelist.size() == 32) {
                        List<Classifications> results = predictor.batchPredict(imagelist);
                        System.out.println("index: " + idx + " " + name + " " + img.getWidth() + " " + img.getHeight());
                        list.add(results.toString());
                        imagelist.clear();
                    }
                }
                // Predict the images left over after the last full batch.
                if (!imagelist.isEmpty()) {
                    list.add(predictor.batchPredict(imagelist).toString());
                    imagelist.clear();
                }
                return list.iterator();
            }
        });
        result.saveAsTextFile(resultPath);
    }
}
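To run the sample on the cluster, package it into a JAR and submit it with spark-submit. The following is a minimal sketch: the class name matches the sample above, but the JAR name and the three paths are placeholders that you need to replace with your own. The same pattern applies to the other samples in this topic; only the --class name and the input path change.
spark-submit --master yarn --deploy-mode client \
    --class DistributedPredictHDFS \
    distributed-predict.jar \
    hdfs:///images/ hdfs:///model/ hdfs:///predict-result/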
Images stored in HBase
Important: HBase 2.0 or later is required, which corresponds to the 4.x series of E-MapReduce.
You can view DistributedPredictHBase.java in your IDE. The code is as follows.
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.HBaseConfiguration;
import org.apache.hadoop.hbase.client.Result;
import org.apache.hadoop.hbase.client.Scan;
import org.apache.hadoop.hbase.io.ImmutableBytesWritable;
import org.apache.hadoop.hbase.mapreduce.TableInputFormat;
import org.apache.hadoop.hbase.mapreduce.TableMapReduceUtil;
import org.apache.hadoop.hbase.util.Bytes;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;

import ai.djl.Application;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import scala.Tuple2;

public class DistributedPredictHBase {
    public static void main(String[] args) throws Exception {
        System.out.println("Start DistributedPredictHBase Job.");
        String hbasePath = args[0];
        String modelPath = args[1];
        String resultPath = args[2];
        SparkConf conf = new SparkConf().setAppName("DistributedPredictHBaseOnSpark");
        JavaSparkContext sc = new JavaSparkContext(conf);
        // Serialize the scan so that it can be passed to TableInputFormat through the job configuration.
        // TableMapReduceUtil.convertScanToString replaces the manual ProtobufUtil/Base64 serialization
        // and is stable across HBase versions.
        Scan scan = new Scan();
        String scanToString = TableMapReduceUtil.convertScanToString(scan);
        // Build one RDD per HBase table (image0 ... image7) and union them into a single RDD.
        JavaPairRDD<ImmutableBytesWritable, Result> HBaseRdd0 = null;
        for (int i = 0; i < 8; i++) {
            String tablename = "image" + i;
            Configuration hbconf = HBaseConfiguration.create();
            hbconf.set(TableInputFormat.INPUT_TABLE, tablename);
            hbconf.set(TableInputFormat.SCAN_BATCHSIZE, "256");
            hbconf.set(TableInputFormat.SCAN, scanToString);
            hbconf.set("hbase.zookeeper.quorum", hbasePath);
            hbconf.set("hbase.zookeeper.property.clientPort", "2181");
            JavaPairRDD<ImmutableBytesWritable, Result> HBaseRdd = sc.newAPIHadoopRDD(hbconf, TableInputFormat.class,
                    ImmutableBytesWritable.class, Result.class);
            HBaseRdd0 = (HBaseRdd0 == null) ? HBaseRdd : HBaseRdd0.union(HBaseRdd);
        }
        System.out.println("Partitions: " + HBaseRdd0.getNumPartitions());
        JavaRDD<String> resultx = HBaseRdd0.mapPartitions(new FlatMapFunction<Iterator<Tuple2<ImmutableBytesWritable, Result>>, String>() {
            private static final long serialVersionUID = 1L;
            @Override
            public Iterator<String> call(Iterator<Tuple2<ImmutableBytesWritable, Result>> iterator) throws Exception {
                ImageClassificationTranslator translator =
                        ImageClassificationTranslator.builder()
                                .addTransform(new Resize(224, 224))
                                .addTransform(new ToTensor())
                                .build();
                Criteria<Image, Classifications> criteria =
                        Criteria.builder()
                                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                                .setTypes(Image.class, Classifications.class) // defines input and output data type
                                .optTranslator(translator)
                                .optModelUrls(modelPath)
                                .build();
                System.out.println("EngineName: " + Engine.getInstance().getEngineName());
                System.out.println("ModelName: " + criteria.getModelName());
                System.out.println("CriteriaInfo: " + criteria.toString());
                // Load the model once per partition and reuse the predictor for all batches.
                ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
                Predictor<Image, Classifications> predictor = model.newPredictor();
                int idx = 0;
                List<Image> imagelist = new ArrayList<>();
                List<String> rows = new ArrayList<String>();
                while (iterator.hasNext()) {
                    Result result = iterator.next()._2();
                    String rowKey = Bytes.toString(result.getRow());
                    // The image bytes are stored in column family "f", qualifier "body".
                    byte[] body = result.getValue(Bytes.toBytes("f"), Bytes.toBytes("body"));
                    InputStream input = new ByteArrayInputStream(body);
                    Image img = ImageFactory.getInstance().fromInputStream(input);
                    imagelist.add(img);
                    idx++;
                    // Run inference once a full batch of 64 images has been collected.
                    if (imagelist.size() == 64) {
                        predictor.batchPredict(imagelist);
                        System.out.println("index: " + idx + " " + rowKey);
                        imagelist.clear();
                    }
                    rows.add(rowKey);
                }
                // Predict the images left over after the last full batch.
                if (!imagelist.isEmpty()) {
                    predictor.batchPredict(imagelist);
                    imagelist.clear();
                }
                // The processed row keys are returned and saved as the job output.
                return rows.iterator();
            }
        });
        resultx.saveAsTextFile(resultPath);
    }
}
Images stored in OSS
You can view DistributedPredictOSS.java in your IDE. The code is as follows.
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.input.PortableDataStream;

import ai.djl.Application;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import scala.Tuple2;

public class DistributedPredictOSS {
    public static void main(String[] args) throws Exception {
        System.out.println("Start DistributedPredictOSS Job.");
        String ossPath = args[0];
        String modelPath = args[1];
        String resultPath = args[2];
        SparkConf conf = new SparkConf().setAppName("DistributedPredictOSSOnSpark");
        JavaSparkContext sc = new JavaSparkContext(conf);
        // binaryFiles also accepts oss:// paths when the cluster has OSS filesystem support configured.
        JavaPairRDD<String, PortableDataStream> imageStream = sc.binaryFiles(ossPath, 128);
        System.out.println("Partitions: " + imageStream.getNumPartitions());
        JavaRDD<String> result = imageStream.mapPartitions(new FlatMapFunction<Iterator<Tuple2<String, PortableDataStream>>, String>() {
            private static final long serialVersionUID = 1L;
            @Override
            public Iterator<String> call(Iterator<Tuple2<String, PortableDataStream>> iterator) throws Exception {
                ImageClassificationTranslator translator =
                        ImageClassificationTranslator.builder()
                                .addTransform(new Resize(224, 224))
                                .addTransform(new ToTensor())
                                .build();
                Criteria<Image, Classifications> criteria =
                        Criteria.builder()
                                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                                .setTypes(Image.class, Classifications.class) // defines input and output data type
                                .optTranslator(translator)
                                .optModelUrls(modelPath)
                                .build();
                System.out.println("EngineName: " + Engine.getInstance().getEngineName());
                System.out.println("ModelName: " + criteria.getModelName());
                System.out.println("CriteriaInfo: " + criteria.toString());
                List<String> list = new ArrayList<>();
                // Load the model once per partition and reuse the predictor for all batches.
                ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
                Predictor<Image, Classifications> predictor = model.newPredictor();
                int idx = 0;
                List<Image> imagelist = new ArrayList<>();
                while (iterator.hasNext()) {
                    Tuple2<String, PortableDataStream> item = iterator.next();
                    String name = item._1();
                    PortableDataStream content = item._2();
                    Image img = ImageFactory.getInstance().fromInputStream(content.open());
                    imagelist.add(img);
                    idx++;
                    // Run inference once a full batch of 32 images has been collected.
                    if (imagelist.size() == 32) {
                        List<Classifications> results = predictor.batchPredict(imagelist);
                        System.out.println("index: " + idx + " " + name);
                        list.add(results.toString());
                        imagelist.clear();
                    }
                }
                // Predict the images left over after the last full batch.
                if (!imagelist.isEmpty()) {
                    list.add(predictor.batchPredict(imagelist).toString());
                    imagelist.clear();
                }
                return list.iterator();
            }
        });
        result.saveAsTextFile(resultPath);
    }
}
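If the cluster does not already have OSS access configured, the job can be pointed at OSS through the hadoop-aliyun connector. The following is a minimal sketch, assuming the connector is on the classpath; the endpoint and credential values are placeholders. The properties are set on the JavaSparkContext in main() before sc.binaryFiles is called.
// Hypothetical configuration sketch for the hadoop-aliyun OSS connector.
sc.hadoopConfiguration().set("fs.oss.impl", "org.apache.hadoop.fs.aliyun.oss.AliyunOSSFileSystem");
sc.hadoopConfiguration().set("fs.oss.endpoint", "<yourOssEndpoint>");
sc.hadoopConfiguration().set("fs.oss.accessKeyId", "<yourAccessKeyId>");
sc.hadoopConfiguration().set("fs.oss.accessKeySecret", "<yourAccessKeySecret>");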
Multiple images merged into one big file stored in HDFS
The advantage of this scenario is high throughput. The drawback is that you must first use the ConvertImageToBase64 tool to encode the images as Base64 and merge them into one big file stored in HDFS.
You can view DistributedPredictHDFSBigFile.java in your IDE. The code is as follows.
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;

import ai.djl.Application;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;

public class DistributedPredictHDFSBigFile {
    public static void main(String[] args) throws Exception {
        System.out.println("Start DistributedPredictHDFSBigFile Job.");
        String imagesPath = args[0];
        String modelPath = args[1];
        String resultPath = args[2];
        SparkConf conf = new SparkConf().setAppName("DistributedPredictHDFSBigFileOnSpark");
        JavaSparkContext sc = new JavaSparkContext(conf);
        // Each line of the big file is one Base64-encoded image.
        JavaRDD<String> imageStream_base64 = sc.textFile(imagesPath);
        System.out.println("Partitions: " + imageStream_base64.getNumPartitions());
        JavaRDD<byte[]> imageStream_bytes = imageStream_base64.map(new Function<String, byte[]>() {
            private static final long serialVersionUID = 1L;
            @Override
            public byte[] call(String in) throws Exception {
                // Decode one Base64 line back into the raw image bytes.
                return Base64.getDecoder().decode(in);
            }
        });
        System.out.println("Partitions: " + imageStream_bytes.getNumPartitions());
        JavaRDD<String> result = imageStream_bytes.mapPartitions(new FlatMapFunction<Iterator<byte[]>, String>() {
            private static final long serialVersionUID = 1L;
            @Override
            public Iterator<String> call(Iterator<byte[]> iterator) throws Exception {
                ImageClassificationTranslator translator =
                        ImageClassificationTranslator.builder()
                                .addTransform(new Resize(224, 224))
                                .addTransform(new ToTensor())
                                .build();
                Criteria<Image, Classifications> criteria =
                        Criteria.builder()
                                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                                .setTypes(Image.class, Classifications.class) // defines input and output data type
                                .optTranslator(translator)
                                .optModelUrls(modelPath)
                                .build();
                System.out.println("EngineName: " + Engine.getInstance().getEngineName());
                System.out.println("ModelName: " + criteria.getModelName());
                System.out.println("CriteriaInfo: " + criteria.toString());
                List<String> list = new ArrayList<>();
                // Load the model once per partition and reuse the predictor for all batches.
                ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
                Predictor<Image, Classifications> predictor = model.newPredictor();
                int idx = 0;
                List<Image> imagelist = new ArrayList<>();
                while (iterator.hasNext()) {
                    byte[] body = iterator.next();
                    InputStream input = new ByteArrayInputStream(body);
                    Image img = ImageFactory.getInstance().fromInputStream(input);
                    imagelist.add(img);
                    idx++;
                    // Run inference once a full batch of 32 images has been collected.
                    if (imagelist.size() == 32) {
                        List<Classifications> results = predictor.batchPredict(imagelist);
                        System.out.println("index: " + idx + " " + img.getWidth() + " " + img.getHeight());
                        list.add(results.toString());
                        imagelist.clear();
                    }
                }
                // Predict the images left over after the last full batch.
                if (!imagelist.isEmpty()) {
                    list.add(predictor.batchPredict(imagelist).toString());
                    imagelist.clear();
                }
                return list.iterator();
            }
        });
        result.saveAsTextFile(resultPath);
    }
}
Convert images to Base64
You can view ConvertImageToBase64.java in your IDE. The code is as follows.
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.util.Base64;

public class ConvertImageToBase64 {
    public static void main(String[] args) throws Exception {
        String filename = "car.jpg";
        // Read the whole image file and encode it as one Base64 string.
        byte[] fileBytes = Files.readAllBytes(new File(filename).toPath());
        String encoded = Base64.getEncoder().encodeToString(fileBytes);
        /* decode, as a round-trip sanity check */
        byte[] decoded = Base64.getDecoder().decode(encoded);
        assert java.util.Arrays.equals(fileBytes, decoded);
        // Write the encoded image as a single line; one line per image builds the merged big file.
        encoded += '\n';
        try (OutputStream out = new BufferedOutputStream(new FileOutputStream(filename + ".base64", false))) {
            out.write(encoded.getBytes());
        }
    }
}
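The sample above encodes a single image. To build the merged big file that DistributedPredictHDFSBigFile consumes, you can loop over a local image directory and append one Base64 line per image. MergeImagesToBase64File below is a hypothetical helper sketch, not part of the sample project; upload its output file to HDFS afterwards.
import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.util.Base64;

public class MergeImagesToBase64File {
    public static void main(String[] args) throws Exception {
        File dir = new File(args[0]);                             // local directory that contains the images
        try (OutputStream out = new FileOutputStream(args[1])) {  // merged output file
            for (File f : dir.listFiles()) {
                // One Base64-encoded image per line, so sc.textFile() later yields one image per record.
                String line = Base64.getEncoder().encodeToString(Files.readAllBytes(f.toPath())) + "\n";
                out.write(line.getBytes());
            }
        }
    }
}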
Import images into HBase
HBase provides high-performance key-value storage. You can store massive numbers of images in HBase, which delivers significantly better I/O efficiency than storing the same small files directly in HDFS.
You can view ImportImageToHBase.java in your IDE. The code is as follows.
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.HBaseConfiguration;
import org.apache.hadoop.hbase.TableName;
import org.apache.hadoop.hbase.client.Connection;
import org.apache.hadoop.hbase.client.ConnectionFactory;
import org.apache.hadoop.hbase.client.Get;
import org.apache.hadoop.hbase.client.Put;
import org.apache.hadoop.hbase.client.Result;
import org.apache.hadoop.hbase.client.Table;
import org.apache.hadoop.hbase.util.Bytes;

public class ImportImageToHBase {
    public static void main(String[] args) throws Exception {
        String h_table = args[0];
        String filename = "car.jpg";
        Configuration configuration = HBaseConfiguration.create();
        configuration.set("hbase.zookeeper.quorum", "192.168.0.*:2181");
        Connection connection = ConnectionFactory.createConnection(configuration);
        Table table = connection.getTable(TableName.valueOf(h_table));
        /* put the same test image into HBase 100,000 times under different row keys */
        byte[] fileBytes = Files.readAllBytes(new File(filename).toPath());
        for (int i = 0; i < 100000; i++) {
            String key = i + ".jpg";
            Put put = new Put(Bytes.toBytes(key));
            put.addColumn(Bytes.toBytes("f"), Bytes.toBytes("body"), fileBytes);
            try {
                table.put(put);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        System.out.println("put 100000 images done!");
        /* read one image back from HBase to verify the import */
        Get get = new Get(Bytes.toBytes("0.jpg")); // row keys were written as "<i>.jpg"
        Result result = table.get(get);
        byte[] body = result.getValue(Bytes.toBytes("f"), Bytes.toBytes("body"));
        try (OutputStream out = new BufferedOutputStream(new FileOutputStream(filename + ".bk", false))) {
            out.write(body);
        }
        table.close();
        connection.close();
    }
}
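DistributedPredictHBase reads from eight tables named image0 through image7, so those tables must exist with column family f before you import images. The following is a minimal sketch using the HBase 2.x Admin API; it assumes the connection variable from the sample above and is not part of the sample project.
// Hypothetical sketch: pre-create the tables image0 ... image7 with column family "f".
Admin admin = connection.getAdmin();
for (int i = 0; i < 8; i++) {
    TableName tn = TableName.valueOf("image" + i);
    if (!admin.tableExists(tn)) {
        admin.createTable(TableDescriptorBuilder.newBuilder(tn)
                .setColumnFamily(ColumnFamilyDescriptorBuilder.of("f"))
                .build());
    }
}
admin.close();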