I am trying to prevent near-duplicate images from entering my system, so each incoming image has to be checked against all existing ones — this is not a one-to-one similarity check. Sometimes the images are not byte-identical but almost identical; for example, see these two images (note the slight brightness difference):
- https://pbs.twimg.com/amplify_video_thumb/1864293853525684224/img/Xj98d6CgNDezU3SE.jpg
- https://pbs.twimg.com/amplify_video_thumb/1864298609203671040/img/LxQ6a9wx9npYpGAQ.jpg
Therefore I compute a perceptual hash for each image and store it in a database; each incoming image's hash is then checked against the database.
I implemented a solution in Java that works correctly, but it consumes too much CPU. I am running it on an 8-CPU droplet and the expected processing rate is 100 requests/second, but the current solution already reaches 100% CPU usage at 30–40 requests/second, which is not acceptable.
Do you have any suggestions for a simpler (cheaper) process, or should I create a separate Python service to compute the image hashes?
import java.awt.*;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URL;
import java.net.URLConnection;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.imageio.ImageIO;
import javax.imageio.stream.ImageInputStream;
import javax.imageio.stream.MemoryCacheImageInputStream;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class ImageUtil {
    /** Side length, in pixels, of the grayscale thumbnail the hash is computed from. */
    private static final int HASH_SIZE = 16;
    /** Side length of the window of thumbnail pixels that become hash bits (8 x 8 = 64 bits). */
    private static final int INNER_SIZE = 8;
    /** Offset of that window from the thumbnail's top-left corner (a 4-pixel border on every side). */
    private static final int OFFSET = 4;
    /** Connect timeout for image downloads; a plain URL connection would otherwise wait forever. */
    private static final int CONNECT_TIMEOUT_MS = 5_000;
    /** Read timeout for image downloads; a plain URL connection would otherwise wait forever. */
    private static final int READ_TIMEOUT_MS = 10_000;

    /**
     * Downloads and hashes every URL in parallel.
     *
     * <p>URLs that are malformed, cannot be fetched, or do not decode to an image are logged and
     * left out of the result, so the returned list may be shorter than {@code urls}.
     *
     * @param urls absolute image URLs
     * @return one {@link ImageHash} per successfully hashed URL; no entry has a {@code null} hash
     */
    public static List<ImageHash> computeHashesBatch(List<String> urls) {
        // NOTE(review): parallelStream runs on the shared ForkJoinPool.commonPool and every task
        // blocks on network I/O here; a dedicated, sized executor would isolate this workload.
        return urls.parallelStream()
                .map(ImageUtil::hashOrNull)
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
    }

    /** Hashes a single URL, returning {@code null} (after logging the reason) if that fails. */
    private static ImageHash hashOrNull(String url) {
        try {
            String hash = computePerceptualHash(url);
            if (hash == null) {
                // Previously this produced an ImageHash with a null hash that slipped past the filter.
                log.warn("No ImageIO reader could decode {}", url);
                return null;
            }
            return new ImageHash(url, hash);
        } catch (IOException | IllegalArgumentException e) {
            // IllegalArgumentException comes from malformed URLs; one bad URL must not abort the batch.
            log.warn("Could not hash image {}", url, e);
            return null;
        }
    }

    /**
     * Downloads the image at {@code url} and returns its perceptual hash.
     *
     * <p>The result has the form {@code <type><width><height>-<16 lowercase hex digits>}. The prefix
     * is the decoded image's {@link BufferedImage#getType()}, width and height. The suffix is the
     * 64-bit gradient hash described on {@code computeHashFromImage}.
     *
     * <p>NOTE(review): the prefix fields are concatenated without separators, so for example width
     * 12 / height 34 and width 123 / height 4 give the same prefix. Because the dimensions are part
     * of the string, a resized copy never matches by string equality. Near-duplicate detection
     * needs the hex part compared by Hamming distance. The format is intentionally unchanged so
     * that hashes already stored in the database stay comparable.
     *
     * @param url absolute URL of the image
     * @return the hash, or {@code null} if ImageIO has no reader for the downloaded content
     * @throws IOException if the image cannot be downloaded or decoded
     * @throws IllegalArgumentException if {@code url} is not a valid absolute URI
     */
    public static String computePerceptualHash(String url) throws IOException {
        BufferedImage img = readImage(URI.create(url).toURL());
        if (img == null) {
            return null;
        }
        String hash = computeHashFromImage(img);
        return String.format("%s%s%s-%s", img.getType(), img.getWidth(), img.getHeight(), hash);
    }

    /**
     * Reads an image from {@code url} with bounded timeouts, buffering the stream in memory.
     *
     * <p>{@code ImageIO.read(URL)} builds its input via {@code ImageIO.createImageInputStream}. With
     * the default {@link ImageIO#getUseCache()} setting, that call buffers into a temporary file on
     * disk for every request. Wrapping the stream in a {@link MemoryCacheImageInputStream} avoids
     * that file I/O without touching the JVM-wide ImageIO cache setting.
     *
     * @return the decoded image, or {@code null} if no registered reader recognises the content
     */
    private static BufferedImage readImage(URL url) throws IOException {
        URLConnection connection = url.openConnection();
        connection.setConnectTimeout(CONNECT_TIMEOUT_MS);
        connection.setReadTimeout(READ_TIMEOUT_MS);
        try (InputStream in = connection.getInputStream()) {
            ImageInputStream stream = new MemoryCacheImageInputStream(in);
            BufferedImage img = ImageIO.read(stream);
            if (img == null) {
                // ImageIO.read(ImageInputStream) closes the stream only when a reader was found.
                stream.close();
            }
            return img;
        }
    }

    /**
     * Computes a 64-bit gradient hash of {@code img} as 16 lowercase hex digits.
     *
     * <p>Steps:
     * <ol>
     *   <li>The image is drawn into a HASH_SIZE x HASH_SIZE grayscale thumbnail with bilinear
     *       interpolation.</li>
     *   <li>For each pixel of the central INNER_SIZE x INNER_SIZE window, the gradient
     *       {@code |dx| + |dy|} is computed from central differences.</li>
     *   <li>Each bit is 1 when that pixel's gradient exceeds the window's truncated integer mean.</li>
     * </ol>
     * Bits are emitted row by row, most significant first, four per hex digit.
     */
    private static String computeHashFromImage(BufferedImage img) {
        // NOTE(review): when shrinking by a large factor, Java2D bilinear interpolation likely reads
        // only a few source pixels per output pixel. The thumbnail therefore ignores most of the
        // image and is sensitive to small shifts. Changing the downscale method would change every
        // stored hash, so it is left as-is. Full-resolution JPEG decoding plus this conversion are
        // presumably the dominant per-request CPU costs; profile before tuning.
        BufferedImage processed = new BufferedImage(HASH_SIZE, HASH_SIZE, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = processed.createGraphics();
        try {
            g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
            g.drawImage(img, 0, 0, HASH_SIZE, HASH_SIZE, null);
        } finally {
            g.dispose();
        }
        // One byte per pixel, rows of HASH_SIZE bytes, as indexed below.
        byte[] pixels = ((DataBufferByte) processed.getRaster().getDataBuffer()).getData();

        int[] gradients = new int[INNER_SIZE * INNER_SIZE];
        int gradientIdx = 0;
        long totalGradient = 0;
        for (int y = OFFSET; y < OFFSET + INNER_SIZE; y++) {
            int rowOffset = y * HASH_SIZE;
            for (int x = OFFSET; x < OFFSET + INNER_SIZE; x++) {
                int idx = rowOffset + x;
                // The OFFSET border keeps idx +/- 1 and idx +/- HASH_SIZE inside the thumbnail.
                int gx = (pixels[idx + 1] & 0xFF) - (pixels[idx - 1] & 0xFF);
                int gy = (pixels[idx + HASH_SIZE] & 0xFF) - (pixels[idx - HASH_SIZE] & 0xFF);
                int gradient = Math.abs(gx) + Math.abs(gy);
                totalGradient += gradient;
                gradients[gradientIdx++] = gradient;
            }
        }
        // Truncating integer mean, matching how previously stored hashes were produced.
        int avgGradient = (int) (totalGradient / gradients.length);

        StringBuilder result = new StringBuilder((gradients.length + 3) / 4);
        int accumulator = 0;
        int bitCount = 0;
        for (int gradient : gradients) {
            accumulator = (accumulator << 1) | (gradient > avgGradient ? 1 : 0);
            bitCount++;
            if (bitCount == 4) {
                result.append(Integer.toHexString(accumulator));
                accumulator = 0;
                bitCount = 0;
            }
        }
        // Only reachable if INNER_SIZE * INNER_SIZE is not a multiple of 4: left-align the last nibble.
        if (bitCount > 0) {
            accumulator <<= (4 - bitCount);
            result.append(Integer.toHexString(accumulator));
        }
        return result.toString();
    }

    /** Immutable pairing of a source URL with the hash computed for it. */
    @Getter
    @AllArgsConstructor
    public static class ImageHash {
        private final String url;
        private final String hash;
    }
}
Test code:
/**
 * Hashes the two sample thumbnails from the question, logs each hash next to its URL, and then
 * logs the wall-clock time the whole run took.
 */
@Test
void test() {
    long startedAt = System.currentTimeMillis();
    List<String> sampleUrls = Arrays.asList(
            "https://pbs.twimg.com/amplify_video_thumb/1864293853525684224/img/Xj98d6CgNDezU3SE.jpg",
            "https://pbs.twimg.com/amplify_video_thumb/1864298609203671040/img/LxQ6a9wx9npYpGAQ.jpg");
    List<ImageUtil.ImageHash> results = ImageUtil.computeHashesBatch(sampleUrls);
    for (ImageUtil.ImageHash entry : results) {
        log.info("{} - {}", entry.getHash(), entry.getUrl());
    }
    long elapsedMs = System.currentTimeMillis() - startedAt;
    log.info("Duration: {}ms", elapsedMs);
}
6