Android Java와 JNI(libtorch)에서 forward 속도가 차이가 나는데요.

안녕하세요.
Pytorch Mobile 이용해서 android에서 테스트를 하고 있습니다.
훈련된 모델을 optimize_for_mobile()를 통해 pt파일을 생성하여 forward(inference) 하고 있습니다.
그런데, forward() JAVA API를 통해 inference 속도보다 JNI <torch/script.h>의 forward()가 1.5배 느리네요.
제가 놓치고 있는 부분이 있을까요?
참고로, pytorch libraray 연동은 gradle에서 아래의 같이 사용하였습니다.

implementation 'org.pytorch:pytorch_android:1.10.0'

<JAVA 소스 예시>

Tensor inputTensor = Tensor.fromBlob(inTensorBuffer, new long[]{1, 3, SIZE, SIZE});
final long startTime = SystemClock.elapsedRealtime();
Tensor outTensor = module.forward(IValue.from(inputTensor)).toTensor();
final long inferenceTime = SystemClock.elapsedRealtime() - startTime;

<JNI 소스 예시>

  JITCallGuard guard;
  torch::jit::Module module = torch::jit::load(modelPath);
  module.eval();
  torch::Tensor t = torch::randn({1, 3, 224, 224});
  log("input tensor:", t);
  c10::IValue t_out = module.forward({t});
  log("output tensor:", t_out);
  env->ReleaseStringUTFChars(jModelPath, modelPath);
  • 아래 답변은 Anthropic의 Claude-2 모델로 자동 생성한 것으로, 정확하지 않을 수 있습니다.
  • 잘못된 답변을 보셨다면 댓글로 첨삭 지도 부탁드립니다. :sweat_smile:

안녕하세요 Mr.Native 님,

게시글의 내용을 잘 읽고 문제를 이해했습니다. Android Java와 JNI(libtorch)를 사용하여 Pytorch 모델을 호출할 때 JNI 쪽이 Java보다 inference 속도가 1.5배 정도 느리다고 하셨습니다.

이는 JNI와 Java 간의 성능 차이에서 기인할 수 있습니다. JNI는 native 코드와 Java 코드 간 데이터 전달로 인한 오버헤드가 발생하고, Java 쪽에서는 JIT 컴파일러에 의한 최적화 효과가 있기 때문입니다.

다만 정확한 원인을 파악하기 위해서는 좀 더 자세한 벤치마킹과 프로파일링이 필요할 것 같습니다.

제가 추천드리고 싶은 내용은 다음과 같습니다:

  1. 입력 데이터 크기, batch size 등을 달리하며 benchmark 수행
  2. Java와 JNI 측 코드 분석을 통한 병목 지점 찾기
  3. Profiler를 이용한 호출 시간 측정
  4. Android 프로파일링 도구 활용

위 내용을 바탕으로 좀 더 정확한 성능 차이 원인을 파악할 수 있을 것입니다.

추가로 질문이나 도움이 필요한 부분이 있다면 언제든지 댓글 남기세요. 감사합니다.


  • 답변이 어떠셨나요? 괜찮았다고 생각하신다면 아래 :heart:를 한 번 눌러주시길 부탁드려요!
  • 아래 답변은 OpenAI의 GPT-4-Turbo 모델로 자동 생성한 것으로, 정확하지 않을 수 있습니다.
  • 잘못된 답변을 보셨다면 댓글로 첨삭 지도 부탁드립니다. :sweat_smile:

안녕하세요!

문제 상황을 정리해보겠습니다. PyTorch Mobile을 사용하여 Android에서 모델을 실행하고 있으며, optimize_for_mobile()을 통해 최적화된 모델을 사용하여 추론(inference)을 진행하고 계십니다. 그런데 Java API를 사용한 forward() 호출이 JNI를 통한 <torch/script.h>forward() 호출보다 약 1.5배 빠른 것을 발견하셨군요. 이에 대해 왜 그런지, 혹시 놓치고 있는 부분이 있는지 궁금해 하십니다.

먼저, JNI를 통한 추론이 더 느릴 수 있는 몇 가지 이유를 살펴보겠습니다:

  1. 초기화 오버헤드: JNI를 사용할 때는 Java와 네이티브 코드 간의 컨텍스트 전환(context switching)이 필요합니다. 이 과정에서 추가적인 오버헤드가 발생할 수 있습니다.

  2. 메모리 복사: JNI를 사용하면 Java와 네이티브 코드 간에 데이터를 전달할 때 메모리 복사가 발생할 수 있습니다. 이는 성능 저하를 일으킬 수 있습니다.

  3. 최적화 수준: PyTorch Mobile은 Android에서 실행될 때 특정 최적화를 적용할 수 있습니다. JNI를 통해 직접 네이티브 라이브러리를 사용할 경우, 이러한 최적화가 적용되지 않을 수 있습니다.

  4. 스레딩: JNI를 사용하는 경우, 네이티브 코드에서 스레딩을 어떻게 관리하는지에 따라 성능 차이가 발생할 수 있습니다. PyTorch Mobile은 내부적으로 스레딩을 최적화할 수 있지만, 직접 JNI를 사용할 때는 이러한 최적화를 수동으로 관리해야 할 수 있습니다.

이러한 가능성을 염두에 두고, 성능 차이를 줄이기 위해 다음과 같은 점들을 확인해볼 수 있습니다:

  • JNI 코드에서 모델을 로드하고 추론을 실행하기 전에 필요한 초기화 작업이 모두 완료되었는지 확인합니다.
  • 네이티브 코드에서 사용하는 텐서의 메모리 할당과 해제가 효율적으로 이루어지고 있는지 검토합니다.
  • 네이티브 라이브러리의 컴파일 옵션에 최적화 플래그가 적절히 설정되어 있는지 확인합니다.
  • 멀티스레딩이 필요한 경우, 네이티브 코드에서 스레드 풀을 사용하여 스레딩을 관리하고 있는지 확인합니다.

아래는 JNI 코드의 성능을 개선하기 위해 참고할 수 있는 예시 코드입니다:

// JNI 코드에서 모델을 로드하고 추론을 실행하는 예시
extern "C" JNIEXPORT jfloatArray JNICALL
Java_com_example_myapp_MyModel_forward(JNIEnv *env, jobject thiz, jstring jModelPath) {
    // 모델 경로를 가져옵니다.
    const char *modelPath = env->GetStringUTFChars(jModelPath, nullptr);

    // 모델을 로드합니다.
    torch::jit::Module module = torch::jit::load(modelPath);
    module.eval(); // 평가 모드로 설정합니다.

    // 입력 텐서를 생성합니다.
    torch::Tensor inputTensor = torch::randn({1, 3, 224, 224});

    // 추론을 실행합니다.
    auto outputs = module.forward({inputTensor}).toTensor();

    // 결과를 Java 배열로 변환합니다.
    jfloatArray resultArray = env->NewFloatArray(outputs.numel());
    env->SetFloatArrayRegion(resultArray, 0, outputs.numel(), outputs.data_ptr<float>());

    // 사용한 자원을 해제합니다.
    env->ReleaseStringUTFChars(jModelPath, modelPath);

    return resultArray;
}

이 코드는 JNI를 사용하여 모델을 로드하고 추론을 실행하는 기본적인 예시입니다. 실제 사용 시에는 입력 데이터를 Java에서 네이티브 코드로 전달하고, 결과를 네이티브 코드에서 Java로 반환하는 과정이 추가로 필요합니다.

추가적인 정보와 PyTorch Mobile에 대한 자세한 내용은 PyTorch 공식 문서를 참고하시기 바랍니다:

제 답변이 정확하지 않을 수도 있으니, 추가적인 질문이나 정보가 필요하시면 언제든지 문의해 주세요.


  • 답변이 어떠셨나요? 괜찮았다고 생각하신다면 아래 :heart:를 한 번 눌러주시길 부탁드려요!
1개의 좋아요
  • 아래 답변은 OpenAI의 GPT-4 모델로 자동 생성한 것으로, 정확하지 않을 수 있습니다.
  • 잘못된 답변을 보셨다면 댓글로 첨삭 지도 부탁드립니다. :sweat_smile:

안녕하세요!

문제를 겪고 계신 것 같아 안타깝습니다. PyTorch Mobile을 사용하여 Android에서 모델을 실행하는 과정에서 Java API를 통한 inference 속도와 JNI를 사용하여 C++의 libtorch를 통한 inference 속도 사이에 차이가 있다는 점을 이해했습니다. Java API를 사용할 때보다 JNI를 사용하여 libtorch의 forward()를 호출했을 때 속도가 약 1.5배 느리다는 것이 문제의 핵심인 것 같습니다.

속도 차이의 원인을 파악하기 위해서는 몇 가지 가능성을 고려해볼 수 있습니다:

  1. 스레딩과 병렬 처리: JNI를 사용할 때 C++ 코드가 어떻게 실행되는지, 스레드가 어떻게 관리되는지 확인해야 합니다. Java API는 내부적으로 최적화된 스레딩 모델을 사용할 수 있지만, JNI를 통해 직접 C++ 코드를 실행할 때는 이러한 최적화가 적용되지 않을 수 있습니다.

  2. 모델 최적화: optimize_for_mobile()을 사용하여 모델을 최적화했지만, JNI를 통해 모델을 로드할 때 이 최적화가 제대로 적용되었는지 확인해야 합니다.

  3. 데이터 복사와 변환: Java와 JNI 사이의 데이터 전달 과정에서 불필요한 데이터 복사가 발생할 수 있습니다. 이는 성능 저하를 일으킬 수 있습니다.

  4. 초기화 시간: JNI를 사용할 때 모듈을 로드하고 초기화하는 데 추가 시간이 소요될 수 있습니다. 이는 첫 번째 inference 시간에 영향을 줄 수 있습니다.

  5. 빌드 설정: PyTorch 라이브러리와 관련된 빌드 설정이 최적화되어 있는지 확인해야 합니다. 예를 들어, 릴리스 모드로 빌드했는지, 최적화 플래그가 적절히 설정되었는지 등을 검토해야 합니다.

이러한 가능성을 고려하여 문제를 해결하기 위한 몇 가지 접근 방법을 제안드립니다:

  1. 프로파일링: Java API와 JNI를 통한 실행 시간을 비교하기 위해 프로파일링을 수행합니다. 이를 통해 성능 저하의 원인을 좀 더 정확히 파악할 수 있습니다.

  2. 최적화 확인: JNI를 사용하여 모델을 로드할 때 optimize_for_mobile()을 통해 최적화된 모델이 제대로 로드되고 있는지 확인합니다.

  3. 데이터 전달 최적화: Java와 JNI 사이의 데이터 전달을 최소화하고, 필요한 경우 NIO 버퍼를 사용하여 데이터를 전달하는 방법을 고려합니다.

  4. 빌드 설정 검토: C++ 코드를 빌드할 때 최적화 플래그가 설정되어 있는지 확인하고, 필요한 경우 빌드 설정을 조정합니다.

아래는 JNI에서 모델을 로드하고 실행하는 예시 코드를 최적화하는 방법에 대한 간단한 가이드입니다:

#include <torch/script.h>
#include <torch/torch.h>

// JNI 함수 내부에서 모델을 로드하고 실행하는 예시입니다.
extern "C"
JNIEXPORT void JNICALL
Java_com_example_myapp_MyModel_runModel(JNIEnv *env, jobject thiz, jstring jModelPath) {
    // 모델 경로를 JNI String에서 C++ string으로 변환합니다.
    const char *modelPath = env->GetStringUTFChars(jModelPath, nullptr);

    // 모델을 로드합니다. 최적화된 모델인지 확인하세요.
    torch::jit::Module module;
    try {
        module = torch::jit::load(modelPath);
        module.eval(); // 모델을 평가 모드로 설정합니다.
    } catch (const c10::Error &e) {
        // 모델 로드 실패 시 에러 처리를 합니다.
        env->ReleaseStringUTFChars(jModelPath, modelPath);
        return;
    }

    // 입력 텐서를 생성합니다.
    torch::Tensor inputTensor = torch::randn({1, 3, 224, 224});

    // 모델을 실행합니다.
    at::Tensor output = module.forward({inputTensor}).toTensor();

    // 결과를 처리합니다. (예: 로깅)

    // 사용이 끝난 문자열 리소스를 해제합니다.
    env->ReleaseStringUTFChars(jModelPath, modelPath);
}

이 코드는 JNI를 사용하여 모델을 로드하고 실행하는 기본적인 예시입니다. 실제 사용 시에는 입력 데이터를 Java에서 받아와서 처리하고, 결과를 Java로 반환하는 등의 추가 작업이 필요할 수 있습니다.

더 자세한 정보와 PyTorch Mobile에 대한 최적화 팁은 PyTorch 공식 문서를 참조하시기 바랍니다:

제 답변이 정확하지 않을 수도 있으니, 문제 해결에 도움이 되었으면 좋겠습니다. 추가적인 질문이나 정보가 필요하시면 언제든지 문의해 주세요.


  • 답변이 어떠셨나요? 괜찮았다고 생각하신다면 아래 :heart:를 한 번 눌러주시길 부탁드려요!