Skip to content

Commit b81a387

Browse files
authored
fix: correctly pass custom llm prompt parameter (microsoft#1319)
* fix: correctly pass custom llm prompt parameter
1 parent ea1a3df commit b81a387

3 files changed

Lines changed: 53 additions & 2 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -164,14 +164,14 @@ result = md.convert("test.pdf")
164164
print(result.text_content)
165165
```
166166

167-
To use Large Language Models for image descriptions, provide `llm_client` and `llm_model`:
167+
To use Large Language Models for image descriptions (currently only for pptx and image files), provide `llm_client` and `llm_model`:
168168

169169
```python
170170
from markitdown import MarkItDown
171171
from openai import OpenAI
172172

173173
client = OpenAI()
174-
md = MarkItDown(llm_client=client, llm_model="gpt-4o")
174+
md = MarkItDown(llm_client=client, llm_model="gpt-4o", llm_prompt="optional custom prompt")
175175
result = md.convert("example.jpg")
176176
print(result.text_content)
177177
```

packages/markitdown/src/markitdown/_markitdown.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,7 @@ def __init__(
115115
# TODO - remove these (see enable_builtins)
116116
self._llm_client: Any = None
117117
self._llm_model: Union[str | None] = None
118+
self._llm_prompt: Union[str | None] = None
118119
self._exiftool_path: Union[str | None] = None
119120
self._style_map: Union[str | None] = None
120121

@@ -139,6 +140,7 @@ def enable_builtins(self, **kwargs) -> None:
139140
# TODO: Move these into converter constructors
140141
self._llm_client = kwargs.get("llm_client")
141142
self._llm_model = kwargs.get("llm_model")
143+
self._llm_prompt = kwargs.get("llm_prompt")
142144
self._exiftool_path = kwargs.get("exiftool_path")
143145
self._style_map = kwargs.get("style_map")
144146

@@ -559,6 +561,9 @@ def _convert(
559561
if "llm_model" not in _kwargs and self._llm_model is not None:
560562
_kwargs["llm_model"] = self._llm_model
561563

564+
if "llm_prompt" not in _kwargs and self._llm_prompt is not None:
565+
_kwargs["llm_prompt"] = self._llm_prompt
566+
562567
if "style_map" not in _kwargs and self._style_map is not None:
563568
_kwargs["style_map"] = self._style_map
564569

packages/markitdown/tests/test_module_misc.py

Lines changed: 46 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import re
55
import shutil
66
import pytest
7+
from unittest.mock import MagicMock
78

89
from markitdown._uri_utils import parse_data_uri, file_uri_to_path
910

@@ -370,6 +371,50 @@ def test_markitdown_exiftool() -> None:
370371
assert target in result.text_content
371372

372373

374+
def test_markitdown_llm_parameters() -> None:
375+
"""Test that LLM parameters are correctly passed to the client."""
376+
mock_client = MagicMock()
377+
mock_response = MagicMock()
378+
mock_response.choices = [
379+
MagicMock(
380+
message=MagicMock(
381+
content="Test caption with red circle and blue square 5bda1dd6"
382+
)
383+
)
384+
]
385+
mock_client.chat.completions.create.return_value = mock_response
386+
387+
test_prompt = "You are a professional test prompt."
388+
markitdown = MarkItDown(
389+
llm_client=mock_client, llm_model="gpt-4o", llm_prompt=test_prompt
390+
)
391+
392+
# Test image file
393+
markitdown.convert(os.path.join(TEST_FILES_DIR, "test_llm.jpg"))
394+
395+
# Verify the prompt was passed to the OpenAI API
396+
assert mock_client.chat.completions.create.called
397+
call_args = mock_client.chat.completions.create.call_args
398+
messages = call_args[1]["messages"]
399+
assert len(messages) == 1
400+
assert messages[0]["content"][0]["text"] == test_prompt
401+
402+
# Reset the mock for the next test
403+
mock_client.chat.completions.create.reset_mock()
404+
405+
# TODO: may only use one test after the llm caption method duplicate has been removed:
406+
# https://github.com/microsoft/markitdown/pull/1254
407+
# Test PPTX file
408+
markitdown.convert(os.path.join(TEST_FILES_DIR, "test.pptx"))
409+
410+
# Verify the prompt was passed to the OpenAI API for PPTX images too
411+
assert mock_client.chat.completions.create.called
412+
call_args = mock_client.chat.completions.create.call_args
413+
messages = call_args[1]["messages"]
414+
assert len(messages) == 1
415+
assert messages[0]["content"][0]["text"] == test_prompt
416+
417+
373418
@pytest.mark.skipif(
374419
skip_llm,
375420
reason="do not run llm tests without a key",
@@ -408,6 +453,7 @@ def test_markitdown_llm() -> None:
408453
test_speech_transcription,
409454
test_exceptions,
410455
test_markitdown_exiftool,
456+
test_markitdown_llm_parameters,
411457
test_markitdown_llm,
412458
]:
413459
print(f"Running {test.__name__}...", end="")

0 commit comments

Comments (0)