The task: given two sequences a
and b
of length n
. I need to find amount of the pairs 1 <= l <= r <= n
such that max(a[l], a[l+1], ..., a[r]) = min(a[l], a[l+1], ..., a[r])
.
I am using Sparse Table for min
and max
and I have this code: (straightforward method is for testing):
import java.io.*;
import java.util.StringTokenizer;
public class task {
public static int log(int x) {
return 31 - Integer.numberOfLeadingZeros(x);
}
public static int[][] buildSparseTable(int[] in, boolean isMin) {
int n = in.length;
int l = log(n) + 1;
int[][] res = new int[n][l];
for (int i = 0; i < n; i++) {
res[i][0] = in[i];
}
for (int j = 1; (1 << j) <= n; j++) {
for (int i = 0; (i + (1 << j) - 1) < n; i++) {
int nextIndex = i + (1 << (j - 1));
if (isMin) {
res[i][j] = Math.min(res[i][j - 1], res[nextIndex][j - 1]);
} else {
res[i][j] = Math.max(res[i][j - 1], res[nextIndex][j - 1]);
}
}
}
return res;
}
public static int query(int[][] st, int l, int r, boolean isMin) {
int j = log(r - l + 1);
if (isMin) {
return Math.min(st[l][j], st[r - (1 << j) + 1][j]);
} else {
return Math.max(st[l][j], st[r - (1 << j) + 1][j]);
}
}
public static void main(String[] args) throws IOException {
BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st;
st = new StringTokenizer(reader.readLine());
int n = Integer.parseInt(st.nextToken());
int[] a = new int[n];
int[] b = new int[n];
st = new StringTokenizer(reader.readLine());
for (int i = 0; i < n; i++) {
a[i] = Integer.parseInt(st.nextToken());
}
st = new StringTokenizer(reader.readLine());
for (int i = 0; i < n; i++) {
b[i] = Integer.parseInt(st.nextToken());
}
reader.close();
int[][] A = buildSparseTable(a, false);
int[][] B = buildSparseTable(b, true);
int testans = 0;
for (int i = 0; i < n; i++) {
int minB = Integer.MAX_VALUE;
int maxA = Integer.MIN_VALUE;
for (int j = i; j < n; j++) {
int t = log(j - i + 1);
minB = Math.min(B[i][t], B[j - (1 << t) + 1][t]);
maxA = Math.max(A[i][t], A[j - (1 << t) + 1][t]);
if (minB == maxA) {
testans++;
}
}
}
System.out.println("test == " + testans);
int answer = 0;
for (int i = 0; i < n; i++) {
int l = i;
int r = n - 1;
while (l <= r) {
int m = (l + r) / 2;
int a_t = query(A, i, m, false);
int b_t = query(B, i, m, true);
if (a_t == b_t) {
answer++;
r = m - 1;
} else if (b_t > a_t) {
l = m + 1;
} else {
r = m - 1;
}
}
}
BufferedWriter output = new BufferedWriter(new OutputStreamWriter(System.out));
output.write(answer + "n");
output.flush();
output.close();
}
}
But it fails on some tests and I can’t understand why (all the tests I created gave the right answers)…