Commit 527472b
Fixed remaining tests
whitead committed Jan 4, 2024
1 parent 0009c0c commit 527472b
Showing 6 changed files with 99 additions and 36 deletions.
44 changes: 26 additions & 18 deletions paperqa/docs.py
@@ -49,7 +49,7 @@ class Docs(BaseModel):
docnames: set[str] = set()
texts_index: VectorStore = NumpyVectorStore()
doc_index: VectorStore = NumpyVectorStore()
llm_config: dict = dict(model="gpt-3.5-turbo", model_type="chat")
llm_config: dict = dict(model="gpt-3.5-turbo", model_type="chat", temperature=0.1)
summary_llm_config: dict | None = Field(default=None, validate_default=True)
name: str = "default"
index_path: Path | None = PAPERQA_DIR / name
@@ -61,6 +61,7 @@ class Docs(BaseModel):
jit_texts_index: bool = False
# This is used to strip indirect citations that come up from the summary llm
strip_citations: bool = True
verbose: bool = False

def __init__(self, **data):
if "client" in data:
@@ -183,15 +184,17 @@ def add(
cite_chain = make_chain(
client=self._client,
prompt=self.prompts.cite,
llm_config=self.summary_llm_config,
llm_config=cast(dict, self.summary_llm_config),
skip_system=True,
)
# peek at first chunk
fake_doc = Doc(docname="", citation="", dockey=dockey)
texts = read_doc(path, fake_doc, chunk_chars=chunk_chars, overlap=100)
if len(texts) == 0:
raise ValueError(f"Could not read document {path}. Is it empty?")
citation = asyncio.run(cite_chain(data=dict(text=texts[0].text)))
citation = asyncio.run(
cite_chain(dict(text=texts[0].text), None),
)
if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation:
citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}"

@@ -312,23 +315,31 @@ async def adoc_match(
)
papers = [f"{d.docname}: {d.citation}" for d in matched_docs]
result = await chain(
data=[dict(question=query, papers="\n".join(papers))],
callbacks=get_callbacks("filter"),
dict(question=query, papers="\n".join(papers)),
get_callbacks("filter"),
)
return set([d.dockey for d in matched_docs if d.docname in result])
except AttributeError:
pass
return set([d.dockey for d in matched_docs])

def _build_texts_index(self, keys: set[DocKey] | None = None):
texts = self.texts
if keys is not None and self.jit_texts_index:
texts = self.texts
if keys is not None:
texts = [t for t in texts if t.doc.dockey in keys]
if len(texts) == 0:
return
self.texts_index.clear()
self.texts_index.add_texts_and_embeddings(texts)
if self.jit_texts_index and keys is None:
# Not sure what else to do here???????
print(
"Warning: JIT text index without keys "
"requires rebuilding index each time!"
)
self.texts_index.clear()
self.texts_index.add_texts_and_embeddings(texts)

def get_evidence(
self,
@@ -369,7 +380,6 @@ async def aget_evidence(
# do we have no docs?
return answer
self._build_texts_index(keys=answer.dockey_filter)
self.texts_index = cast(VectorStore, self.texts_index)
_k = k
if answer.dockey_filter is not None:
_k = k * 10 # heuristic - get enough so we can downselect
@@ -414,7 +424,7 @@ async def process(match):
summary_chain = make_chain(
client=self._client,
prompt=self.prompts.summary,
llm_config=self.summary_llm_config,
llm_config=cast(dict, self.summary_llm_config),
system_prompt=self.prompts.system,
)
# This is dangerous because it
@@ -425,14 +435,14 @@ async def process(match):
# http code in the exception
try:
context = await summary_chain(
data=dict(
dict(
question=answer.question,
# Add name so chunk is stated
citation=citation,
summary_length=answer.summary_length,
text=match.text,
),
callbacks=callbacks,
callbacks,
)
except Exception as e:
if guess_is_4xx(str(e)):
@@ -544,9 +554,7 @@ async def aquery(
llm_config=self.llm_config,
system_prompt=self.prompts.system,
)
pre = await chain(
data=dict(question=answer.question), callbacks=get_callbacks("pre")
)
pre = await chain(dict(question=answer.question), get_callbacks("pre"))
answer.context = answer.context + "\n\nExtra background information:" + pre
bib = dict()
if len(answer.context) < 10: # and not self.memory:
@@ -560,14 +568,16 @@ async def aquery(
llm_config=self.llm_config,
system_prompt=self.prompts.system,
)
print(answer.context)
answer_text = await qa_chain(
data=dict(
dict(
context=answer.context,
answer_length=answer.answer_length,
question=answer.question,
),
callbacks=get_callbacks("answer"),
get_callbacks("answer"),
)
print(answer_text)
# it still happens
if "(Example2012)" in answer_text:
answer_text = answer_text.replace("(Example2012)", "")
@@ -594,9 +604,7 @@ async def aquery(
llm_config=self.llm_config,
system_prompt=self.prompts.system,
)
post = await chain(
data=answer.model_dump(), callbacks=get_callbacks("post")
)
post = await chain(answer.model_dump(), get_callbacks("post"))
answer.answer = post
answer.formatted_answer = f"Question: {answer.question}\n\n{post}\n"
if len(bib) > 0:
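As an aside (not part of the diff): with the new defaults above, a bare Docs instance should pick up gpt-3.5-turbo at temperature 0.1 along with the new verbose flag. A minimal sketch, assuming Docs() can be constructed without an explicit client, as the tests below do:

```python
from paperqa import Docs

# Sketch only: relies on the defaults introduced above
# (model="gpt-3.5-turbo", model_type="chat", temperature=0.1)
# and on the new `verbose` field.
docs = Docs()
assert docs.llm_config["temperature"] == 0.1
assert docs.verbose is False
```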
4 changes: 2 additions & 2 deletions paperqa/llms.py
@@ -1,5 +1,5 @@
import re
from typing import Any, Awaitable, Callable, get_args, get_type_hints
from typing import Any, Callable, Coroutine, get_args, get_type_hints

from openai import AsyncOpenAI

@@ -54,7 +54,7 @@ def make_chain(
llm_config: dict,
skip_system: bool = False,
system_prompt: str = default_system_prompt,
) -> Awaitable[Any]:
) -> Callable[[dict, list[Callable[[str], None]] | None], Coroutine[Any, Any, str]]:
"""Create a function to execute a batch of prompts
Args:
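A side note on the signature change above: the chain returned by make_chain is now invoked positionally with a data dict and an optional list of callbacks, and returns a coroutine yielding the completion string. A minimal sketch of the new calling convention (assumes an OpenAI key is configured; the prompt and callback here are hypothetical, mirroring the tests below):

```python
import asyncio

from openai import AsyncOpenAI
from paperqa.llms import make_chain

client = AsyncOpenAI()
chain = make_chain(
    client=client,
    prompt="The {animal} says",  # hypothetical prompt, as in the tests below
    llm_config=dict(model="gpt-3.5-turbo", model_type="chat", temperature=0),
    skip_system=True,
)

chunks: list[str] = []
# The second positional argument is the callback list (or None).
completion = asyncio.run(chain(dict(animal="duck"), [chunks.append]))
assert completion == "".join(chunks)
```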
5 changes: 3 additions & 2 deletions paperqa/prompts.py
@@ -15,7 +15,7 @@

qa_prompt = (
"Write an answer ({answer_length}) "
"for the question below based on the provided context. "
"for the question below based on the provided context. Ignore irrelevant context. "
"If the context provides insufficient information and the question cannot be directly answered, "
'reply "I cannot answer". '
"For each part of your answer, indicate which sources most support it "
@@ -37,7 +37,8 @@
"Selected keys:"
)
citation_prompt = (
"Provide the citation for the following text in MLA Format. If reporting date accessed, the current year is 2024\n\n"
"Provide the citation for the following text in MLA Format. "
"If reporting date accessed, the current year is 2024\n\n"
"{text}\n\n"
"Citation:"
)
6 changes: 4 additions & 2 deletions paperqa/readers.py
@@ -78,7 +78,9 @@ def parse_txt(
) -> List[Text]:
"""Parse a document into chunks, based on tiktoken encoding.
NOTE: We get some byte continuation errors. Currently ignored, but should explore more to make sure we don't miss anything.
NOTE: We get some byte continuation errors.
Currently ignored, but should explore more to make sure we
don't miss anything.
"""
try:
with open(path) as f:
@@ -88,7 +90,7 @@ def parse_txt(
text = f.read()
if html:
text = html2text(text)
texts = []
texts: list[Text] = []
# we tokenize using tiktoken so cuts are in reasonable places
enc = tiktoken.get_encoding("cl100k_base")
encoded = [enc.decode_single_token_bytes(token) for token in enc.encode(text)]
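For context on the chunking in parse_txt above: the text is tokenized with tiktoken and each token is decoded back to bytes, so chunk boundaries land on token edges; the byte-continuation note in the docstring refers to multi-byte characters split across tokens. A rough, simplified sketch of that pattern (not the actual parse_txt implementation; the errors="ignore" handling is an assumption):

```python
# Rough sketch of token-aligned chunking; not the actual parse_txt implementation.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
text = "Some long document text ..."  # placeholder input
chunk_chars, overlap = 3000, 100

# Decode each token back to bytes so cuts fall on token boundaries.
encoded = [enc.decode_single_token_bytes(token) for token in enc.encode(text)]

chunks: list[str] = []
current = b""
for token_bytes in encoded:
    current += token_bytes
    if len(current) >= chunk_chars:
        # errors="ignore" is one way to swallow the byte-continuation
        # errors mentioned in the docstring (assumption, not the repo's code).
        chunks.append(current.decode("utf-8", errors="ignore"))
        current = current[-overlap:]
if current:
    chunks.append(current.decode("utf-8", errors="ignore"))
```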
18 changes: 13 additions & 5 deletions paperqa/types.py
@@ -15,7 +15,7 @@
# Just for clarity
DocKey = Any

CallbackFactory = Callable[[str], Callable[[str], None]]
CallbackFactory = Callable[[str], list[Callable[[str], None]] | None]


class Embeddable(BaseModel):
@@ -78,7 +78,7 @@ def max_marginal_relevance_search(
return texts, scores

embeddings = np.array([t.embedding for t in texts])
scores = np.array(scores)
np_scores = np.array(scores)
similarity_matrix = cosine_similarity(embeddings, embeddings)

selected_indices = [0]
@@ -88,7 +88,7 @@
selected_similarities = similarity_matrix[:, selected_indices]
max_sim_to_selected = selected_similarities.max(axis=1)

mmr_scores = lambda_ * scores - (1 - lambda_) * max_sim_to_selected
mmr_scores = lambda_ * np_scores - (1 - lambda_) * max_sim_to_selected
mmr_scores[selected_indices] = -np.inf # Exclude already selected documents

max_mmr_index = mmr_scores.argmax()
@@ -132,15 +132,21 @@ def similarity_search(
)


# Mock a dictionary and store any missing items
class _FormatDict(dict):
def __init__(self) -> None:
self.key_set: set[str] = set()

def __missing__(self, key: str) -> str:
self.key_set.add(key)
return key


def get_formatted_variables(s: str) -> set[str]:
"""Returns the set of variables implied by the format string"""
format_dict = _FormatDict()
s.format_map(format_dict)
return set(format_dict.keys())
return format_dict.key_set


class PromptCollection(BaseModel):
@@ -190,6 +196,8 @@ def check_select(cls, v: str) -> str:
@classmethod
def check_pre(cls, v: str | None) -> str | None:
if v is not None:
print(v)
print(get_formatted_variables(v))
if set(get_formatted_variables(v)) != set(["question"]):
raise ValueError("Pre prompt must have input variables: question")
return v
@@ -199,7 +207,7 @@ def check_pre(cls, v: str | None) -> str | None:
def check_post(cls, v: str | None) -> str | None:
if v is not None:
# kind of a hack to get list of attributes in answer
attrs = [a.name for a in Answer.__fields__.values()]
attrs = set(Answer.model_fields.keys())
if not set(get_formatted_variables(v)).issubset(attrs):
raise ValueError(f"Post prompt must have input variables: {attrs}")
return v
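A brief note on the _FormatDict change above: str.format_map calls __missing__ for every unresolved placeholder, and the new key_set records those keys explicitly so get_formatted_variables can report the template's variables. A standalone sketch reusing the code from this diff (the example template is hypothetical):

```python
# Standalone sketch of the helper shown in the diff above.
class _FormatDict(dict):
    """Record every placeholder that str.format_map fails to resolve."""

    def __init__(self) -> None:
        self.key_set: set[str] = set()

    def __missing__(self, key: str) -> str:
        self.key_set.add(key)
        return key


def get_formatted_variables(s: str) -> set[str]:
    """Return the set of variables implied by the format string."""
    format_dict = _FormatDict()
    s.format_map(format_dict)
    return format_dict.key_set


# Hypothetical template, for illustration only.
assert get_formatted_variables("Answer {question} using {context}.") == {
    "question",
    "context",
}
```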
58 changes: 51 additions & 7 deletions tests/test_paperqa.py
@@ -5,9 +5,10 @@

import numpy as np
import requests
from openai import AsyncOpenAI

from paperqa import Answer, Doc, Docs, PromptCollection, Text
from paperqa.llms import get_score
from paperqa.llms import get_score, make_chain
from paperqa.readers import read_doc
from paperqa.utils import (
iter_citations,
@@ -359,6 +360,49 @@ def test_extract_score():
assert get_score(sample) == 9


class TestChains(IsolatedAsyncioTestCase):
async def test_chain_completion(self):
client = AsyncOpenAI()
call = make_chain(
client,
"The {animal} says",
llm_config=dict(
model_type="completion",
temperature=0,
model="babbage-002",
max_tokens=56,
),
skip_system=True,
)
outputs = []

def accum(x):
outputs.append(x)

completion = await call(dict(animal="duck"), callbacks=[accum])
assert completion == "".join(outputs)
assert type(completion) == str

async def test_chain_chat(self):
client = AsyncOpenAI()
call = make_chain(
client,
"The {animal} says",
llm_config=dict(
model_type="chat", temperature=0, model="gpt-3.5-turbo", max_tokens=56
),
skip_system=True,
)
outputs = []

def accum(x):
outputs.append(x)

completion = await call(dict(animal="duck"), callbacks=[accum])
assert completion == "".join(outputs)
assert type(completion) == str


def test_docs():
llm_config = dict(temperature=0.1, model="text-ada-001", model_type="completion")
docs = Docs(llm_config=llm_config)
@@ -454,7 +498,7 @@ def test_docs_pickle():
# get front page of wikipedia
r = requests.get("https://en.wikipedia.org/wiki/Take_Your_Dog_to_Work_Day")
f.write(r.text)
docs = Docs(llm_config=dict(temperature=0.0, model="davinci-002"))
docs = Docs(llm_config=dict(temperature=0.0, model="gpt-3.5-turbo"))
old_config = docs.llm_config
docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now", chunk_chars=1000)
os.remove(doc_path)
@@ -511,7 +555,7 @@ def test_repeat_keys():
# get wiki page about politician
r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)")
f.write(r.text)
docs = Docs(llm_config=dict(temperature=0.0, model="text-ada-001"))
docs = Docs(llm_config=dict(temperature=0.0, model="babbage-002"))
docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now")
try:
docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now")
@@ -540,7 +584,7 @@ def test_repeat_keys():
def test_pdf_reader():
tests_dir = os.path.dirname(os.path.abspath(__file__))
doc_path = os.path.join(tests_dir, "paper.pdf")
docs = Docs(llm_config=dict(temperature=0.0, model="davinci-002"))
docs = Docs(llm_config=dict(temperature=0.0, model="gpt-3.5-turbo"))
docs.add(doc_path, "Wellawatte et al, XAI Review, 2023")
answer = docs.query("Are counterfactuals actionable?")
assert "yes" in answer.answer or "Yes" in answer.answer
@@ -550,15 +594,15 @@ def test_fileio_reader_pdf():
tests_dir = os.path.dirname(os.path.abspath(__file__))
doc_path = os.path.join(tests_dir, "paper.pdf")
with open(doc_path, "rb") as f:
docs = Docs(llm_config=dict(temperature=0.0, model="davinci-002"))
docs = Docs(llm_config=dict(temperature=0.0, model="gpt-3.5-turbo"))
docs.add_file(f, "Wellawatte et al, XAI Review, 2023")
answer = docs.query("Are counterfactuals actionable?")
assert "yes" in answer.answer or "Yes" in answer.answer


def test_fileio_reader_txt():
# can't use curie, because it has trouble with parsed HTML
docs = Docs(llm_config=dict(temperature=0.0, model="davinci-002"))
docs = Docs(llm_config=dict(temperature=0.0, model="gpt-3.5-turbo"))
r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)")
if r.status_code != 200:
raise ValueError("Could not download wikipedia page")
@@ -568,7 +612,7 @@ def test_fileio_reader_txt():
chunk_chars=1000,
)
answer = docs.query("What country was Frederick Bates born in?")
assert "Virginia" in answer.answer
assert "United States" in answer.answer


def test_pdf_pypdf_reader():