I'm trying to run a SageMaker batch transform job with the following configuration:
BatchStrategy="MultiRecord" and "JoinSource": "Input"
with content-type='text/csv'
and accept='text/csv'
However, I'm getting this error:
[sagemaker logs]: my-path-to-s3/batch/input/sample_batch_transform_payload.csv:
Fail to join data: mismatched line count between the input and the output
When I run the same job with
BatchStrategy="SingleRecord"
everything works.
My model outputs 1024-dimensional embeddings, so
the predictions look like the array below (assuming my input CSV had 3 rows):
import numpy as np
predictions = np.array([
    np.random.randn(1024),
    np.random.randn(1024),
    np.random.randn(1024),
])
print(predictions.shape)
# (3, 1024)
I've tried several different return formats in output_fn,
but all of them failed with the same error shown above.
Here is how the AWS sagemaker-inference toolkit implements the output; with it I still get the same error:
from io import StringIO

import numpy as np


def output_fn(predictions, accept="text/csv"):
    # Serialize the predictions as CSV text, one line per row
    stream = StringIO()
    np.savetxt(stream, predictions, delimiter=",", fmt="%s")
    csv_output = stream.getvalue()
    return csv_output
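Since the error is about line counts, one thing that can be checked locally is how many lines this serializer emits for a given prediction shape. The sketch below is only an illustration, not the actual inference script: `count_output_lines` is a hypothetical helper, and the 1-D case is an assumption about what could happen if the predictions were squeezed somewhere before `output_fn`.

```python
from io import StringIO

import numpy as np


def output_fn(predictions, accept="text/csv"):
    # Same serialization as in the post: np.savetxt writes one line per row.
    stream = StringIO()
    np.savetxt(stream, predictions, delimiter=",", fmt="%s")
    return stream.getvalue()


def count_output_lines(csv_output):
    # Hypothetical helper: number of lines in the serialized payload.
    return len(csv_output.splitlines())


batch = np.random.randn(3, 1024)   # 3 input rows, as in the example above
single = np.random.randn(1024)     # assumption: a squeezed, 1-D prediction

lines_batch = count_output_lines(output_fn(batch))
lines_single = count_output_lines(output_fn(single))
print(lines_batch, lines_single)
```

For a 2-D array the line count equals the number of rows, whereas np.savetxt treats a 1-D array as a column and writes one value per line, so the shape passed to output_fn directly determines the number of output lines.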
Here is the configuration of the Transform job:
{
    'TransformJobName': 'ai-manual-transform-job-2024-05-28T10-35-50',
    'TransformJobArn': '<MY-ARN>',
    'TransformJobStatus': 'Failed',
    'FailureReason': 'ClientError: See job logs for more information',
    'ModelName': '<MY-MODEL>',
    'MaxConcurrentTransforms': 1,
    'MaxPayloadInMB': 6,
    'BatchStrategy': 'MultiRecord',
    'TransformInput': {
        'DataSource': {
            'S3DataSource': {
                'S3DataType': 'S3Prefix',
                'S3Uri': '<PATH-To-S3>'
            }
        },
        'ContentType': 'text/csv',
        'CompressionType': 'None',
        'SplitType': 'Line'
    },
    'TransformOutput': {
        'S3OutputPath': '<PATH-to-s3-output>',
        'Accept': 'text/csv',
        'AssembleWith': 'Line',
        'KmsKeyId': ''
    },
    'TransformResources': {'InstanceType': 'ml.m5.large', 'InstanceCount': 1},
    'CreationTime': datetime.datetime(2024, 5, 28, 10, 35, 51, 145000, tzinfo=tzlocal()),
    'TransformStartTime': datetime.datetime(2024, 5, 28, 10, 39, 35, 16000, tzinfo=tzlocal()),
    'TransformEndTime': datetime.datetime(2024, 5, 28, 10, 42, 32, 743000, tzinfo=tzlocal()),
    'DataProcessing': {
        'InputFilter': '$[4]',
        'OutputFilter': '$',
        'JoinSource': 'Input'
    },
}
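With 'SplitType': 'Line', the input is split on physical newlines, so another thing that can be checked locally is whether the number of physical lines in the input file matches the number of CSV records. The sketch below uses a made-up two-record sample (not the real payload) in which the column at index 4 contains a quoted newline; whether the actual file contains such fields is an assumption to verify.

```python
import csv
from io import StringIO

# Made-up sample, NOT the real payload: the second record has a quoted
# newline inside the column at index 4 (the one selected by InputFilter).
sample = 'a,b,c,d,"first text"\na,b,c,d,"second\ntext"\n'

# Lines as seen by a plain line-based splitter.
physical_lines = len(sample.splitlines())

# Records as seen by a CSV parser that honors quoted newlines.
csv_records = sum(1 for _ in csv.reader(StringIO(sample, newline="")))

print(physical_lines, csv_records)
```

If the two counts differ for the real file, a line-based split would see more "records" than the CSV actually contains.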
Any suggestions would be appreciated.