abstracts-index / app.py
colonelwatch's picture
Pull in index from new repository, due to LFS size limits on HF Spaces
92365df
raw
history blame
12 kB
# app.py
# Serves a Gradio search UI over a prebuilt faiss index of title/abstract embeddings: embeds a text query and shows the most similar works.
from dataclasses import dataclass
from itertools import batched, chain
import json
import os
from math import log10
from pathlib import Path
from sys import stderr
from typing import TypedDict, Self, Any, Callable
from datasets import Dataset, disable_caching
from datasets.search import FaissIndex
import faiss
import gradio as gr
import requests
from sentence_transformers import SentenceTransformer
import torch
class IndexParameters(TypedDict):
    """One measured faiss search configuration listed in params.json.

    Fields:
        recall: the 10-recall@10 reached with this configuration.
        exec_time: search time in seconds (the raw faiss measure is in
            milliseconds).
        param_string: faiss parameter string, applied to the index verbatim.
    """

    recall: float
    exec_time: float
    param_string: str
class Params(TypedDict):
    """Schema of the params.json file stored alongside the index.

    Fields:
        dimensions: passed to the model as ``truncate_dim``; may be None.
        normalize: whether query embeddings are normalized before searching.
        optimal_params: candidate faiss configurations, chosen from by recall
            and execution time.
    """

    dimensions: int | None
    normalize: bool
    optimal_params: list[IndexParameters]
@dataclass
class Work:
title: str | None
abstract: str | None # recovered from abstract_inverted_index
authors: list[str] # takes raw_author_name field from Authorship objects
journal_name: str | None # takes the display_name field of the first location
year: int
citations: int
doi: str | None
def __post_init__(self):
self._check_type(self.title, str, nullable=True)
self._check_type(self.abstract, str, nullable=True)
self._check_type(self.authors, list)
for author in self.authors:
self._check_type(author, str)
self._check_type(self.journal_name, str, nullable=True)
self._check_type(self.year, int)
self._check_type(self.citations, int)
self._check_type(self.doi, str, nullable=True)
@classmethod
def from_dict(cls, d: dict) -> Self:
inverted_index: None | dict[str, list[int]] = d["abstract_inverted_index"]
abstract = cls._recover_abstract(inverted_index) if inverted_index else None
try:
journal_name = d["primary_location"]["source"]["display_name"]
except (TypeError, KeyError): # key didn't exist or a value was null
journal_name = None
return cls(
title=d["title"],
abstract=abstract,
authors=[authorship["raw_author_name"] for authorship in d["authorships"]],
journal_name=journal_name,
year=d["publication_year"],
citations=d["cited_by_count"],
doi=d["doi"],
)
@staticmethod
def get_raw_fields() -> list[str]:
return [
"title",
"abstract_inverted_index",
"authorships",
"primary_location",
"publication_year",
"cited_by_count",
"doi"
]
@staticmethod
def _check_type(v: Any, t: type, nullable: bool = False):
if not ((nullable and v is None) or isinstance(v, t)):
v_type_name = f"{type(v)}" if v is not None else "None"
t_name = f"{t}"
if nullable:
t_name += " | None"
raise ValueError(f"expected {t_name}, got {v_type_name}")
@staticmethod
def _recover_abstract(inverted_index: dict[str, list[int]]) -> str:
abstract_size = max(max(locs) for locs in inverted_index.values())+1
abstract_words: list[str | None] = [None] * abstract_size
for word, locs in inverted_index.items():
for loc in locs:
abstract_words[loc] = word
return " ".join(word for word in abstract_words if word is not None)
def get_env_var[T, U](
key: str, type_: Callable[[str], T] = str, default: U = None
) -> T | U:
var = os.getenv(key)
if var is not None:
var = type_(var)
else:
var = default
return var
def get_model(
    model_name: str, params_dir: Path, trust_remote_code: bool
) -> tuple[bool, SentenceTransformer]:
    """Load the embedding model configured by ``params_dir / "params.json"``.

    Returns ``(normalize, model)``. ``normalize`` is the file's "normalize"
    flag. The model is built with ``truncate_dim`` set to the file's
    "dimensions" entry, which may be None.
    """
    # TODO: params["normalize"] for models like all-MiniLM-v6, which already normalize?
    params_path = params_dir / "params.json"
    with open(params_path, "r") as params_file:
        params: Params = json.load(params_file)
    normalize = params["normalize"]
    model = SentenceTransformer(
        model_name,
        trust_remote_code=trust_remote_code,
        truncate_dim=params["dimensions"],
    )
    return normalize, model
def open_ondisk(dir: Path) -> faiss.Index:
    """Read ``dir / "index.faiss"``, resolving its on-disk data next to that file."""
    index_path = str(dir / "index.faiss")
    # IO_FLAG_ONDISK_SAME_DIR: otherwise read_index looks for the on-disk
    # indices in the working directory instead of alongside index.faiss
    io_flags = faiss.IO_FLAG_ONDISK_SAME_DIR
    return faiss.read_index(index_path, io_flags)
def _choose_param_string(
    candidates: list[IndexParameters], search_time_s: float
) -> str:
    """Pick the highest-recall faiss parameter string that runs within the budget.

    Among candidates whose ``exec_time`` is strictly under ``search_time_s``, the
    first one with the highest ``recall`` wins. If none fits the budget, the
    fastest candidate is used instead and a warning is printed to stderr.

    Raises:
        ValueError: if ``candidates`` is empty.
    """
    if not candidates:
        raise ValueError("params.json lists no optimal_params to choose from")
    under = [p for p in candidates if p["exec_time"] < search_time_s]
    if under:
        return max(under, key=lambda p: p["recall"])["param_string"]
    fastest = min(candidates, key=lambda p: p["exec_time"])
    print(
        f"warning: no index parameters run under {search_time_s} s, using the "
        f"fastest ({fastest['exec_time']} s) instead",
        file=stderr,
    )
    return fastest["param_string"]


def get_index(dir: Path, search_time_s: float) -> Dataset:
    """Load the searchable index stored in ``dir`` and tune it for speed.

    ``ids.parquet`` becomes a ``Dataset`` whose "embedding" index is the on-disk
    faiss index from ``index.faiss``. The faiss parameters are chosen from
    ``params.json``: the best recall among sets whose measured ``exec_time``
    (seconds) is under ``search_time_s``, else the fastest set.

    Raises:
        ValueError: if ``params.json`` lists no parameter sets.
    """
    # NOTE: use a private attr to load the index with IO_FLAG_ONDISK_SAME_DIR!
    index: Dataset = Dataset.from_parquet(str(dir / "ids.parquet"))  # type: ignore
    faiss_index = open_ondisk(dir)
    index._indexes["embedding"] = FaissIndex(None, None, None, faiss_index)

    with open(dir / "params.json", "r", encoding="utf-8") as f:
        params: Params = json.load(f)
    param_string = _choose_param_string(params["optimal_params"], search_time_s)

    ps = faiss.ParameterSpace()
    ps.initialize(faiss_index)
    ps.set_index_parameters(faiss_index, param_string)
    return index
def execute_request(
    ids: list[str], mailto: str | None, timeout: float = 30.0
) -> list[Work]:
    """Fetch works from the OpenAlex ``/works`` endpoint, in the order of ``ids``.

    Args:
        ids: OpenAlex IDs, in the same form as the ``id`` field of the API
            results (results are matched back on that field). Duplicates are
            allowed and are requested only once.
        mailto: sent as the ``mailto`` query parameter when not None.
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        One ``Work`` per entry of ``ids``, in the same order; [] for no IDs.

    Raises:
        ValueError: if more than 100 distinct IDs are given (one result page).
        requests.HTTPError: if the API answers with an error status.
        KeyError: if the API returns no result for some requested ID.
    """
    unique_ids = list(dict.fromkeys(ids))  # dedupe, keeping first-seen order
    if not unique_ids:
        return []  # an empty openalex_id filter would be rejected by the API
    if len(unique_ids) > 100:
        raise ValueError("querying /works endpoint with more than 100 works")

    # query with the /works endpoint with a specific list of IDs and fields
    search_filter = "openalex_id:" + "|".join(unique_ids)
    search_select = ",".join(["id"] + Work.get_raw_fields())
    params = {"filter": search_filter, "select": search_select, "per-page": 100}
    if mailto is not None:
        params["mailto"] = mailto

    # without a timeout, requests can wait on a stalled connection indefinitely
    response = requests.get("https://api.openalex.org/works", params, timeout=timeout)
    response.raise_for_status()

    # the response is not necessarily ordered, so order them
    works_by_id = {d["id"]: Work.from_dict(d) for d in response.json()["results"]}
    missing = [id_ for id_ in unique_ids if id_ not in works_by_id]
    if missing:
        raise KeyError(f"OpenAlex returned no work for ids: {missing}")
    return [works_by_id[id_] for id_ in ids]
def collapse_newlines(x: str) -> str:
    """Replace every line break in ``x`` with a single space.

    A CR LF pair counts as one break. Any remaining lone CR or LF is replaced
    by one space each.
    """
    without_crlf = x.replace("\r\n", " ")
    return without_crlf.translate(str.maketrans({"\n": " ", "\r": " "}))
def format_response(
    neighbors: list[Work], distances: list[float], calculate_similarity: bool = False
) -> str:
    """Render search hits as Markdown, one section per work.

    Each section has:
      - a heading: the title, linked to the DOI when there is one;
      - a bold byline: up to three authors, the year, and the journal;
      - the abstract, truncated to 2000 characters;
      - an italic footer: cited-by count, DOI, and either the similarity or
        the raw distance.

    Args:
        neighbors: works to render, in display order.
        distances: distance for each work, paired with ``neighbors`` by
            position (extra items in the longer list are ignored).
        calculate_similarity: show ``1 - distance / 2`` as "Similarity" instead
            of the raw "Distance"; intended for unit-normalized embeddings.

    Returns:
        The concatenated Markdown; "" when there are no works.
    """
    entries: list[str] = []
    for work, distance in zip(neighbors, distances):
        # heading: title (or placeholder), linked to the DOI when there is one
        title = collapse_newlines(work.title) if work.title else "No title"
        heading = f"[{title}]({work.doi})" if work.doi else title

        # byline: at most 3 authors, with the ellipsis only when some were cut
        if len(work.authors) > 3:
            authors = ", ".join(work.authors[:3]) + ", ..."
        elif work.authors:
            authors = ", ".join(work.authors)
        else:
            authors = "No author"
        byline = f"{collapse_newlines(authors)}, {work.year}"
        if work.journal_name:
            byline += " - " + collapse_newlines(work.journal_name)

        if work.abstract:
            abstract = collapse_newlines(work.abstract)
            if len(abstract) > 2000:
                abstract = abstract[:2000] + "..."
        else:
            abstract = "No abstract"

        meta: list[tuple[str, str]] = []
        if work.citations:  # don't tack "Cited-by count: 0" onto someone's work
            meta.append(("Cited-by count", str(work.citations)))
        if work.doi:
            meta.append(("DOI", work.doi.replace("https://doi.org/", "")))
        if calculate_similarity:
            # if query and result are unit vectors, the cosine sim is
            # 1 - dist^2 / 2; faiss already gives dist^2, so halve it directly
            meta.append(("Similarity", f"{1 - distance / 2:.2f}"))
        else:
            meta.append(("Distance", f"{distance:.2f}"))
        footer = ("&nbsp;" * 4).join(": ".join(pair) for pair in meta)

        entries.append(f"## {heading}\n\n**{byline}**\n\n{abstract}\n\n*{footer}*\n")
    return "".join(entries)
def _describe_quantity(n_entries: int) -> str:
    """Spell out a count for the page header, e.g. 123456789 -> "123 million".

    Counts below 1 are printed as-is. Counts of a trillion or more stay in
    billions.
    """
    if n_entries < 1:  # log10 is undefined for 0 and negatives
        return str(n_entries)
    n_digits = int(log10(n_entries))
    divisor, postfix = {
        0: (1, ""),
        1: (1000, " thousand"),
        2: (1000000, " million"),
        3: (1000000000, " billion"),
    }[min(n_digits // 3, 3)]
    significand = n_entries / divisor
    # NOTE(review): only two-integer-digit significands keep a decimal
    # ("12.3 million"); one-digit ones round to an integer ("1 million") -- confirm
    significand = round(significand, 1 if (n_digits % 3 == 1) else None)
    return str(significand) + postfix


def _model_link_parts(model_name: str) -> tuple[str, str]:
    """Return (display name, Hugging Face URL) for a model name.

    The display name is the part after the last "/". A name without "/" is
    accepted as-is instead of failing to unpack.
    """
    model_human_name = model_name.rpartition("/")[2]
    return model_human_name, f"https://huggingface.co/{model_name}"


def main():
    """Configure from environment variables, load the model and index, and serve.

    Environment variables (defaults in parentheses): MODEL_NAME
    (sentence-transformers/all-MiniLM-L6-v2), PROMPT_NAME (None),
    TRUST_REMOTE_CODE (False), FP16 (False), DIR (faiss/index),
    SEARCH_TIME_S (1), K (20), MAILTO (None).
    """
    # TODO: figure out some better defaults?
    # fully qualified so the model link in the page header resolves
    model_name = get_env_var(
        "MODEL_NAME", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    prompt_name = get_env_var("PROMPT_NAME")
    trust_remote_code = get_env_var("TRUST_REMOTE_CODE", bool, default=False)
    fp16 = get_env_var("FP16", bool, default=False)
    index_dir = get_env_var("DIR", Path, default=Path("faiss/index"))
    search_time_s = get_env_var("SEARCH_TIME_S", float, default=1)
    k = get_env_var("K", int, default=20)  # TODO: can't go higher than 20 yet
    mailto = get_env_var("MAILTO", str, None)

    disable_caching()  # disable caching in the datasets library
    normalize, model = get_model(model_name, index_dir, trust_remote_code)
    index = get_index(index_dir, search_time_s)

    model.eval()
    if torch.cuda.is_available():
        model = model.half().cuda() if fp16 else model.bfloat16().cuda()
        # TODO: if huggingface datasets exposes an fp16 gpu option, use it here
    elif fp16:
        print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
    model.compile(mode="reduce-overhead")

    # function signature: (expanded tuple of input batches) -> tuple of output batches
    def search(query: list[str]) -> tuple[list[str]]:
        """Gradio batch handler: turn a batch of queries into Markdown results."""
        query_embedding = model.encode(
            query, prompt_name=prompt_name, normalize_embeddings=normalize
        )
        distances, faiss_ids = index.search_batch("embedding", query_embedding, k)

        # faiss pads a row with id -1 when it finds fewer than k neighbors; drop
        # those, since index[-1] would silently resolve to the last row instead
        hits = [
            [(int(i), float(d)) for i, d in zip(row_ids, row_dists) if i >= 0]
            for row_ids, row_dists in zip(faiss_ids, distances)
        ]
        faiss_ids_flat = [i for i, _ in chain.from_iterable(hits)]
        openalex_ids_flat = index[faiss_ids_flat]["id"] if faiss_ids_flat else []

        # execute_request accepts at most 100 IDs per call, so chunk the batch
        works_flat = list(chain.from_iterable(
            execute_request(list(chunk), mailto)
            for chunk in batched(openalex_ids_flat, 100)
        ))

        # rows can have different lengths after dropping padding, so slice by length
        result_strings = []
        offset = 0
        for row in hits:
            row_works = works_flat[offset:offset + len(row)]
            offset += len(row)
            result_strings.append(format_response(
                row_works, [d for _, d in row], calculate_similarity=normalize
            ))
        return (result_strings, )

    with gr.Blocks() as demo:
        # figure out the words to describe the quantity
        quantity = _describe_quantity(len(index))

        # split the (huggingface) model name and get the link
        model_human_name, model_link = _model_link_parts(model_name)

        gr.Markdown("# abstracts-index")
        gr.Markdown(
            f"Explore {quantity} academic publications selected from the "
            "[OpenAlex](https://openalex.org) dataset (as of January 1st, 2025). This "
            "project is an index of the embeddings generated from their titles and "
            "abstracts. The embeddings were generated using the "
            f"[{model_human_name}]({model_link}) model, and the index was built using "
            "the [faiss](https://github.com/facebookresearch/faiss) module. The build "
            "scripts and more information available at the main repo "
            "[abstracts-search](https://github.com/colonelwatch/abstracts-search) on "
            "Github."
        )

        query = gr.Textbox(
            lines=1, placeholder="Enter your query here", show_label=False
        )
        btn = gr.Button("Search")
        results = gr.Markdown(
            latex_delimiters=[
                {"left": "$$", "right": "$$", "display": False},
                {"left": "$", "right": "$", "display": False},
            ],
            container=True,
        )

        query.submit(search, inputs=[query], outputs=[results], batch=True)
        btn.click(search, inputs=[query], outputs=[results], batch=True)

    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()