Skip to content

Commit d0b56a3

Browse files
authored
Only show the top contexts to the agent for the specific question (#912)
1 parent 8501bc8 commit d0b56a3

File tree

4 files changed

+37
-3
lines changed

4 files changed

+37
-3
lines changed

paperqa/agents/tools.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,15 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
251251

252252
status = state.status
253253
logger.info(status)
254+
# only show top n contexts for this particular question to the agent
254255
sorted_contexts = sorted(
255-
state.session.contexts, key=lambda x: x.score, reverse=True
256+
[
257+
c
258+
for c in state.session.contexts
259+
if (c.question is None or c.question == question)
260+
],
261+
key=lambda x: x.score,
262+
reverse=True,
256263
)
257264

258265
top_contexts = "\n".join(

paperqa/core.py

+1
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ async def map_fxn_summary(
207207
return (
208208
Context(
209209
context=context,
210+
question=question,
210211
text=Text(
211212
text=text.text,
212213
name=text.name,

paperqa/types.py

+8
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,13 @@ class Context(BaseModel):
110110
model_config = ConfigDict(extra="allow")
111111

112112
context: str = Field(description="Summary of the text with respect to a question.")
113+
question: str | None = Field(
114+
default=None,
115+
description=(
116+
"Question that the context is summarizing for. "
117+
"Note this can differ from the user query."
118+
),
119+
)
113120
text: Text
114121
score: int = 5
115122

@@ -236,6 +243,7 @@ def filter_content_for_user(self) -> None:
236243
self.contexts = [
237244
Context(
238245
context=c.context,
246+
question=c.question,
239247
score=c.score,
240248
text=Text(
241249
text="",

tests/test_agents.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -670,10 +670,27 @@ def new_status(state: EnvironmentState) -> str:
670670

671671
# now adjust to give the agent 2x pieces of evidence
672672
gather_evidence_tool.settings.agent.agent_evidence_n = 2
673+
# also reset the question to ensure that contexts are
674+
# only returned to the agent for the new question
675+
new_question = "How does XAI relate to a self-explanatory model?"
673676
response = await gather_evidence_tool.gather_evidence(
674-
session.question, state=env_state
677+
new_question, state=env_state
675678
)
676-
679+
assert len({c.question for c in session.contexts}) == 2, "Expected 2 questions"
680+
# now we make sure no contexts from the old question were returned
681+
for context in session.contexts:
682+
if context.question != new_question:
683+
assert (
684+
context.context[:20] not in response
685+
), "gather_evidence should not return any contexts for the old question"
686+
assert (
687+
sum(
688+
(1 if (context.context[:20] in response) else 0)
689+
for context in session.contexts
690+
if context.question == new_question
691+
)
692+
== 2
693+
), "gather_evidence should only return 2 contexts for the new question"
677694
split = re.split(
678695
r"(\d+) pieces of evidence, (\d+) of which were relevant",
679696
response,
@@ -899,6 +916,7 @@ def test_answers_are_striped() -> None:
899916
contexts=[
900917
Context(
901918
context="bla",
919+
question="foo",
902920
text=Text(
903921
name="text",
904922
text="The meaning of life is 42.",

0 commit comments

Comments
 (0)