---
language:
  - en
tags:
  - feature-extraction
  - sentence-similarity
datasets:
  - biu-nlp/abstract-sim
widgets:
  - sentence-similarity
  - feature-extraction
---

A model for mapping abstract sentence descriptions to sentences that fit those descriptions, trained on Wikipedia. Use `load_finetuned_model()` to load the tokenizer, the query encoder, and the sentence encoder, and `encode_batch()` to encode a batch of sentences with either encoder.

Note: the method uses a dual-encoder architecture. This model is the sentence encoder; it should be used together with the query encoder (`biu-nlp/abstract-sim-query`).
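Because the query encoder and the sentence encoder run independently, sentence embeddings can be computed once and reused for every new query; retrieval then reduces to ranking sentences by cosine similarity to the query embedding. A minimal NumPy sketch of that ranking step, using hypothetical toy vectors in place of real encoder outputs (`top_k_sentences` and `normalize_rows` are illustrative helpers, not part of this model's code):

```python
from typing import List, Tuple

import numpy as np


def normalize_rows(x: np.ndarray) -> np.ndarray:
    """Scale each row to unit L2 norm (guarding against all-zero rows)."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.clip(norms, 1e-12, None)


def top_k_sentences(query_embedding: np.ndarray, embeddings: np.ndarray,
                    sentences: List[str], k: int = 3) -> List[Tuple[str, float]]:
    """Return the k sentences with the highest cosine similarity to the query.

    query_embedding: (1, dim); embeddings: (n, dim), one row per sentence.
    """
    q = normalize_rows(query_embedding)[0]   # unit-length query vector, shape (dim,)
    sims = normalize_rows(embeddings) @ q    # cosine similarities, shape (n,)
    k = min(k, len(sentences))
    best = np.argsort(-sims)[:k]             # indices of the k largest scores
    return [(sentences[i], float(sims[i])) for i in best]


# Hypothetical 2-d vectors standing in for real encoder outputs.
toy_sentences = ["a", "b", "c"]
toy_embeddings = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
toy_query = np.array([[1.0, 0.2]])
print(top_k_sentences(toy_query, toy_embeddings, toy_sentences, k=2))  # ranks 'a' first, then 'c'
```

With the arrays from the usage example below, `top_k_sentences(query_embedding, embeddings, all_sentences, k=3)` should return the same top three sentences as the sorted `cosine_similarity` output, up to floating-point rounding, since both compute cosine similarity.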


```python
from typing import List

import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer


def load_finetuned_model():
    """Load the shared tokenizer and the query and sentence encoders."""
    sentence_encoder = AutoModel.from_pretrained("biu-nlp/abstract-sim-sentence")
    query_encoder = AutoModel.from_pretrained("biu-nlp/abstract-sim-query")
    tokenizer = AutoTokenizer.from_pretrained("biu-nlp/abstract-sim-sentence")
    return tokenizer, query_encoder, sentence_encoder


def encode_batch(model, tokenizer, sentences: List[str], device: str):
    """Encode sentences as the mean of their token vectors, skipping the first token.

    The model is expected to already be on `device`.
    """
    inputs = tokenizer(sentences, padding=True, max_length=512, truncation=True,
                       return_tensors="pt", add_special_tokens=True).to(device)
    hidden = model(**inputs)[0]  # last hidden states: (batch, seq_len, dim)
    # Attention mask without position 0, as a float column: (batch, seq_len - 1, 1).
    mask = inputs["attention_mask"][:, 1:].unsqueeze(-1).to(hidden.dtype)
    summed = torch.sum(hidden[:, 1:, :] * mask, dim=1)
    counts = torch.clamp(torch.sum(mask, dim=1), min=1e-9)  # avoid division by zero
    return summed / counts
```
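The pooling in `encode_batch` averages the token vectors that the attention mask marks as real, after dropping position 0 (the tokenizer's leading special token). The sketch below repeats the same arithmetic in NumPy on a hypothetical toy batch of two sentences, four token positions, and 2-dimensional hidden states, to show that neither the first token nor padded positions affect the result (`masked_mean_skip_first` is an illustrative name, not part of the model's code):

```python
import numpy as np


def masked_mean_skip_first(hidden: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """NumPy mirror of the pooling in encode_batch.

    hidden: (batch, seq_len, dim) token vectors; attention_mask: (batch, seq_len) of 0/1.
    """
    mask = attention_mask[:, 1:, None].astype(hidden.dtype)  # drop position 0, add a feature axis
    summed = (hidden[:, 1:, :] * mask).sum(axis=1)            # sum over real (unpadded) tokens
    counts = np.clip(mask.sum(axis=1), 1e-9, None)            # number of real tokens, never zero
    return summed / counts


# Toy batch (hypothetical values): position 0 holds [9, 9] in both rows and is skipped.
hidden = np.array([
    [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[9.0, 9.0], [2.0, 2.0], [4.0, 4.0], [100.0, 100.0]],  # last position is padding
])
attention_mask = np.array([[1, 1, 1, 1], [1, 1, 1, 0]])

pooled = masked_mean_skip_first(hidden, attention_mask)
print(pooled)  # rows: [3. 4.] and [3. 3.]
```

The first row averages `[1, 2]`, `[3, 4]`, `[5, 6]`; the second averages only `[2, 2]` and `[4, 4]`, so the padded `[100, 100]` vector is ignored.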

Usage example:

```python
tokenizer, query_encoder, sentence_encoder = load_finetuned_model()
relevant_sentences = ["Fingersoft's parent company is the Finger Group.",
                      "WHIRC – a subsidiary company of Wright-Hennepin",
                      "CK Life Sciences International (Holdings) Inc. (), or CK Life Sciences, is a subsidiary of CK Hutchison Holdings",
                      "EM Microelectronic-Marin (subsidiary of The Swatch Group).",
                      "The company is currently a division of the corporate group Jam Industries.",
                      "Volt Technical Resources is a business unit of Volt Workforce Solutions, a subsidiary of Volt Information Sciences (currently trading over-the-counter as VISI.)."
                      ]

irrelevant_sentences = ["The second company is deemed to be a subsidiary of the parent company.",
                        "The company has gone through more than one incarnation.",
                        "The company is owned by its employees.",
                        "Larger companies compete for market share by acquiring smaller companies that may own a particular market sector.",
                        "A parent company is a company that owns 51% or more voting stock in another firm (or subsidiary).",
                        "It is a holding company that provides services through its subsidiaries in the following areas: oil and gas, industrial and infrastructure, government and power.",
                        "RXVT Technologies is no longer a subsidiary of the parent company."
                        ]

all_sentences = relevant_sentences + irrelevant_sentences
# The query is an abstract description, prefixed with "<query>: ".
query = "<query>: A company is a part of a larger company."

# Candidate sentences go through the sentence encoder, the query through the query encoder.
embeddings = encode_batch(sentence_encoder, tokenizer, all_sentences, "cpu").detach().cpu().numpy()
query_embedding = encode_batch(query_encoder, tokenizer, [query], "cpu").detach().cpu().numpy()

# Rank all sentences by cosine similarity to the query.
sims = cosine_similarity(query_embedding, embeddings)[0]
sentences_sims = list(zip(all_sentences, sims))
sentences_sims.sort(key=lambda x: x[1], reverse=True)

for s, sim in sentences_sims:
    print(s, sim)
```

Expected output:

```
WHIRC – a subsidiary company of Wright-Hennepin 0.9396286
EM Microelectronic-Marin (subsidiary of The Swatch Group). 0.93929046
Fingersoft's parent company is the Finger Group. 0.936247
CK Life Sciences International (Holdings) Inc. (), or CK Life Sciences, is a subsidiary of CK Hutchison Holdings 0.9350312
The company is currently a division of the corporate group Jam Industries. 0.9273489
Volt Technical Resources is a business unit of Volt Workforce Solutions, a subsidiary of Volt Information Sciences (currently trading over-the-counter as VISI.). 0.9005086
The second company is deemed to be a subsidiary of the parent company. 0.6723645
It is a holding company that provides services through its subsidiaries in the following areas: oil and gas, industrial and infrastructure, government and power. 0.60081375
A parent company is a company that owns 51% or more voting stock in another firm (or subsidiary). 0.59490484
The company is owned by its employees. 0.55286574
RXVT Technologies is no longer a subsidiary of the parent company. 0.4321953
The company has gone through more than one incarnation. 0.38889483
Larger companies compete for market share by acquiring smaller companies that may own a particular market sector. 0.25472647
```
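In this output every sentence from `relevant_sentences` scores above every sentence from `irrelevant_sentences`. A short check over the printed scores (copied from the expected output, grouped by the two lists of the usage example) makes the size of that gap explicit; it is an illustration of the listed numbers only, not part of the model's code:

```python
# Scores copied from the expected output, grouped by the usage example's two lists.
relevant_scores = [0.9396286, 0.93929046, 0.936247, 0.9350312, 0.9273489, 0.9005086]
irrelevant_scores = [0.6723645, 0.60081375, 0.59490484, 0.55286574,
                     0.4321953, 0.38889483, 0.25472647]

# Gap between the weakest relevant match and the strongest irrelevant one.
margin = min(relevant_scores) - max(irrelevant_scores)
print(f"margin = {margin:.7f}")  # prints "margin = 0.2281441"
```

The closest irrelevant sentence ("The second company is deemed to be a subsidiary of the parent company.") still trails the weakest relevant match by about 0.23.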