|
4 | 4 | import re |
5 | 5 | import shutil |
6 | 6 | import pytest |
| 7 | +from unittest.mock import MagicMock |
7 | 8 |
|
8 | 9 | from markitdown._uri_utils import parse_data_uri, file_uri_to_path |
9 | 10 |
|
@@ -370,6 +371,50 @@ def test_markitdown_exiftool() -> None: |
370 | 371 | assert target in result.text_content |
371 | 372 |
|
372 | 373 |
|
| 374 | +def test_markitdown_llm_parameters() -> None: |
| 375 | + """Test that LLM parameters are correctly passed to the client.""" |
| 376 | + mock_client = MagicMock() |
| 377 | + mock_response = MagicMock() |
| 378 | + mock_response.choices = [ |
| 379 | + MagicMock( |
| 380 | + message=MagicMock( |
| 381 | + content="Test caption with red circle and blue square 5bda1dd6" |
| 382 | + ) |
| 383 | + ) |
| 384 | + ] |
| 385 | + mock_client.chat.completions.create.return_value = mock_response |
| 386 | + |
| 387 | + test_prompt = "You are a professional test prompt." |
| 388 | + markitdown = MarkItDown( |
| 389 | + llm_client=mock_client, llm_model="gpt-4o", llm_prompt=test_prompt |
| 390 | + ) |
| 391 | + |
| 392 | + # Test image file |
| 393 | + markitdown.convert(os.path.join(TEST_FILES_DIR, "test_llm.jpg")) |
| 394 | + |
| 395 | + # Verify the prompt was passed to the OpenAI API |
| 396 | + assert mock_client.chat.completions.create.called |
| 397 | + call_args = mock_client.chat.completions.create.call_args |
| 398 | + messages = call_args[1]["messages"] |
| 399 | + assert len(messages) == 1 |
| 400 | + assert messages[0]["content"][0]["text"] == test_prompt |
| 401 | + |
| 402 | + # Reset the mock for the next test |
| 403 | + mock_client.chat.completions.create.reset_mock() |
| 404 | + |
| 405 | + # TODO: may only use one test after the llm caption method duplicate has been removed: |
| 406 | + # https://github.com/microsoft/markitdown/pull/1254 |
| 407 | + # Test PPTX file |
| 408 | + markitdown.convert(os.path.join(TEST_FILES_DIR, "test.pptx")) |
| 409 | + |
| 410 | + # Verify the prompt was passed to the OpenAI API for PPTX images too |
| 411 | + assert mock_client.chat.completions.create.called |
| 412 | + call_args = mock_client.chat.completions.create.call_args |
| 413 | + messages = call_args[1]["messages"] |
| 414 | + assert len(messages) == 1 |
| 415 | + assert messages[0]["content"][0]["text"] == test_prompt |
| 416 | + |
| 417 | + |
373 | 418 | @pytest.mark.skipif( |
374 | 419 | skip_llm, |
375 | 420 | reason="do not run llm tests without a key", |
@@ -408,6 +453,7 @@ def test_markitdown_llm() -> None: |
408 | 453 | test_speech_transcription, |
409 | 454 | test_exceptions, |
410 | 455 | test_markitdown_exiftool, |
| 456 | + test_markitdown_llm_parameters, |
411 | 457 | test_markitdown_llm, |
412 | 458 | ]: |
413 | 459 | print(f"Running {test.__name__}...", end="") |
|
0 commit comments