mirror of
https://github.com/LearningCircuit/local-deep-research.git
synced 2026-06-15 19:46:56 +03:00
fix(llm): normalize str-returns to a message in the central LLM wrapper (+ ainvoke) (#4342)
* fix(llm): normalize str-returns to a message in ProcessingLLMWrapper + add ainvoke The central wrapper (returned by all get_llm paths) stripped <think> tags but returned a bare str when the base LLM returned a str — that inconsistent shape is the root of the recurring "'str' object has no attribute 'content'" crashes we've been fixing site-by-site (#3884 -> #4339). Generic fix at the choke point: - invoke(): when the base returns a bare str, wrap it into AIMessage(content=stripped) instead of returning a str. Message returns are unchanged (mutate .content in place, preserving additional_kwargs/reasoning_content/tool_calls). Other types pass through. - add ainvoke(): mirrors invoke(); without it, the 7 direct .ainvoke() sites (browsecomp_entity/modular strategies) bypassed think-stripping via __getattr__. Now every get_llm LLM yields a think-free str .content on both sync and async direct calls, so the raw .invoke().content sites are safe automatically (deferred per-site migration cancelled). Reasoning-safe: only .content is rewritten, so DeepSeek thinking-mode reasoning_content round-tripping (#4194) is not worsened. Limitation: the LangGraph create_agent path binds tools on the base model (model.bind_tools via __getattr__), so it bypasses this wrapper — unchanged by this PR. Tests: updated the 2 tests asserting a str return; added shape, reasoning_content/ tool_calls-preservation (#4194 guard), and ainvoke regression tests. mypy 552 clean; ruff clean; 2171 passed across 78 LLM-layer test files + citation_handlers. * refactor(llm): extract _log_llm_error helper + add type hints (review polish) Addresses the #4342 review recommendations: - DRY: invoke() and ainvoke() shared the same try/except error-logging verbatim; extracted a _log_llm_error(error) static helper so they can't diverge. - Type hints: added annotations to _normalize_response/_log_llm_error/invoke/ainvoke. No behavior change. ruff + mypy clean; 112 config tests pass.
This commit is contained in:
1
changelog.d/+normalize-llm-wrapper-str-returns.bugfix.md
Normal file
1
changelog.d/+normalize-llm-wrapper-str-returns.bugfix.md
Normal file
@@ -0,0 +1 @@
|
||||
The central LLM wrapper now normalizes string-returning providers into a message object and applies `<think>`-tag stripping to async (`ainvoke`) calls too, so any LLM obtained from `get_llm` yields a consistent, think-free `.content` — eliminating `'str' object has no attribute 'content'` crashes at the source. Message objects keep their `tool_calls`/`reasoning_content` (only `.content` is rewritten).
|
||||
@@ -1,7 +1,9 @@
|
||||
from functools import cache
|
||||
from typing import Any
|
||||
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain_openai import ChatOpenAI
|
||||
from loguru import logger
|
||||
@@ -780,29 +782,50 @@ def wrap_llm_without_think_tags(
|
||||
def __init__(self, base_llm):
|
||||
self.base_llm = base_llm
|
||||
|
||||
def invoke(self, *args, **kwargs):
|
||||
# Removed verbose debug logging to reduce log clutter
|
||||
# Uncomment the lines below if you need to debug LLM requests
|
||||
try:
|
||||
response = self.base_llm.invoke(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception("LLM Request - Failed with error")
|
||||
# Log any URL information from the error
|
||||
error_str = str(e)
|
||||
if "http://" in error_str or "https://" in error_str:
|
||||
logger.exception(
|
||||
f"LLM Request - Error contains URL info: {error_str}"
|
||||
)
|
||||
raise
|
||||
@staticmethod
|
||||
def _normalize_response(response: Any) -> Any:
|
||||
"""Strip <think> tags and normalize the response shape.
|
||||
|
||||
# Process the response content if it has a content attribute
|
||||
A message keeps its object identity (only ``.content`` is rewritten,
|
||||
so ``additional_kwargs``/``reasoning_content``/``tool_calls`` survive).
|
||||
A bare-string return (some providers/wrappers) is wrapped into an
|
||||
``AIMessage`` so callers can always rely on ``.content``. Anything
|
||||
else is passed through unchanged.
|
||||
"""
|
||||
if hasattr(response, "content"):
|
||||
response.content = remove_think_tags(response.content)
|
||||
elif isinstance(response, str):
|
||||
response = remove_think_tags(response)
|
||||
|
||||
response = AIMessage(content=remove_think_tags(response))
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def _log_llm_error(error: Exception) -> None:
|
||||
"""Log an LLM call failure, surfacing any URL embedded in the error."""
|
||||
logger.exception("LLM Request - Failed with error")
|
||||
error_str = str(error)
|
||||
if "http://" in error_str or "https://" in error_str:
|
||||
logger.exception(
|
||||
f"LLM Request - Error contains URL info: {error_str}"
|
||||
)
|
||||
|
||||
def invoke(self, *args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
response = self.base_llm.invoke(*args, **kwargs)
|
||||
except Exception as e:
|
||||
self._log_llm_error(e)
|
||||
raise
|
||||
return self._normalize_response(response)
|
||||
|
||||
async def ainvoke(self, *args: Any, **kwargs: Any) -> Any:
|
||||
# Async counterpart of invoke(); without this, ainvoke() would fall
|
||||
# through __getattr__ to the base LLM and bypass think-tag stripping.
|
||||
try:
|
||||
response = await self.base_llm.ainvoke(*args, **kwargs)
|
||||
except Exception as e:
|
||||
self._log_llm_error(e)
|
||||
raise
|
||||
return self._normalize_response(response)
|
||||
|
||||
# Pass through any other attributes to the base LLM
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.base_llm, name)
|
||||
|
||||
@@ -995,9 +995,11 @@ class TestWrapperStringResponse:
|
||||
):
|
||||
wrapper = wrap_llm_without_think_tags(mock_llm)
|
||||
result = wrapper.invoke("test")
|
||||
# Should remove think tags from string
|
||||
assert "answer" in result
|
||||
assert "<think>" not in result
|
||||
# A bare-string return is normalized into a message (so callers can
|
||||
# always use .content); think tags are still removed.
|
||||
assert not isinstance(result, str)
|
||||
assert "answer" in result.content
|
||||
assert "<think>" not in result.content
|
||||
|
||||
def test_wrapper_handles_invoke_exception(self):
|
||||
"""Should propagate exceptions from LLM invoke."""
|
||||
|
||||
@@ -726,9 +726,11 @@ class TestWrapLlmWithoutThinkTags:
|
||||
llm.invoke.return_value = "<think>thought</think>answer"
|
||||
w = self._make_wrapper(llm)
|
||||
result = w.invoke("prompt")
|
||||
assert isinstance(result, str)
|
||||
assert "<think>" not in result
|
||||
assert "answer" in result
|
||||
# A bare-string return is wrapped into a message so callers can rely on
|
||||
# .content; think tags are still stripped.
|
||||
assert not isinstance(result, str)
|
||||
assert "<think>" not in result.content
|
||||
assert "answer" in result.content
|
||||
|
||||
def test_response_without_content_attr_returned_as_is(self):
|
||||
"""Response that is neither string nor has .content is passed through."""
|
||||
@@ -739,6 +741,51 @@ class TestWrapLlmWithoutThinkTags:
|
||||
result = w.invoke("prompt")
|
||||
assert result == 42
|
||||
|
||||
def test_string_response_wrapped_in_message(self):
|
||||
"""A bare-string return is wrapped into an AIMessage with .content."""
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
llm = MagicMock()
|
||||
llm.invoke.return_value = "<think>t</think>final"
|
||||
w = self._make_wrapper(llm)
|
||||
result = w.invoke("prompt")
|
||||
assert isinstance(result, AIMessage)
|
||||
assert result.content == "final"
|
||||
|
||||
def test_preserves_reasoning_content_and_tool_calls(self):
|
||||
"""Stripping <think> from .content must NOT drop reasoning_content/tool_calls.
|
||||
|
||||
Guards against worsening DeepSeek thinking-mode round-tripping (#4194):
|
||||
we only rewrite .content in place, leaving the rest of the message intact.
|
||||
"""
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
llm = MagicMock()
|
||||
llm.invoke.return_value = AIMessage(
|
||||
content="<think>reasoning</think>answer",
|
||||
additional_kwargs={"reasoning_content": "R"},
|
||||
tool_calls=[
|
||||
{"name": "search", "args": {}, "id": "1", "type": "tool_call"}
|
||||
],
|
||||
)
|
||||
w = self._make_wrapper(llm)
|
||||
result = w.invoke("prompt")
|
||||
assert result.content == "answer"
|
||||
assert result.additional_kwargs["reasoning_content"] == "R"
|
||||
assert result.tool_calls and result.tool_calls[0]["name"] == "search"
|
||||
|
||||
def test_ainvoke_normalizes_string_response(self):
|
||||
"""ainvoke applies the same normalization as invoke (str -> message)."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(return_value="<think>t</think>async answer")
|
||||
w = self._make_wrapper(llm)
|
||||
result = asyncio.run(w.ainvoke("prompt"))
|
||||
assert not isinstance(result, str)
|
||||
assert result.content == "async answer"
|
||||
|
||||
def test_invoke_exception_propagated(self):
|
||||
llm = MagicMock()
|
||||
llm.invoke.side_effect = ConnectionError("timeout")
|
||||
|
||||
Reference in New Issue
Block a user