import torch
from transformers import AutoTokenizer, AutoModel
import os


class TextExtractor:

    def __init__(self, model_name, proxy=None):
        """
        Initialize the TextExtractor with a specified model and optional proxy settings.

        Parameters:
        - model_name (str): The name of the pre-trained model to load from the Hugging Face Hub.
        - proxy (str, optional): The proxy address to use for HTTP and HTTPS requests.
        """
        if proxy is None:
            proxy = 'http://localhost:8234'

        if proxy:
            os.environ['HTTP_PROXY'] = proxy
            os.environ['HTTPS_PROXY'] = proxy

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModel.from_pretrained(model_name)
        except Exception:
            # Fall back to the local cache if the Hub cannot be reached.
            print('Download failed, retrying with local_files_only=True')
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
            self.model = AutoModel.from_pretrained(model_name, local_files_only=True)

        self.model.eval()

    def extract(self, sentences):
        """
        Extract sentence embeddings for the provided sentences.

        Parameters:
        - sentences (list of str): A list of sentences to extract embeddings for.

        Returns:
        - torch.Tensor: The normalized sentence embeddings.
        """
        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

        with torch.no_grad():
            model_output = self.model(**encoded_input)
            # Use the [CLS] token representation as the sentence embedding.
            sentence_embeddings = model_output[0][:, 0]

        # L2-normalize so that dot products equal cosine similarities.
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings
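

# Usage sketch (illustrative only, not executed on import). Because extract()
# returns L2-normalized vectors, the cosine similarity between two sentences is
# simply their dot product:
#
#   extractor = TextExtractor('BAAI/bge-small-zh-v1.5')
#   emb = extractor.extract(["样例数据-1", "样例数据-2"])
#   similarity = emb[0] @ emb[1]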


import pandas as pd


def get_qas(excel_file=None):
    """Load question/answer pairs from the Excel file and return them as a list of dicts."""
    default_excel_file = 'data/output_fixid.xlsx'
    if excel_file is None:
        excel_file = default_excel_file

    df = pd.read_excel(excel_file)

    # Keep only rows that have both a question and a summary.
    df = df[df["question"].notna()]
    df = df[df["summary"].notna()]

    datas = []
    for index, row in df.iterrows():
        id = row['id']
        question = row['question']
        short_answer = row['summary']
        category = row['category']

        texts = [question, short_answer]

        data_value = {
            "texts": texts,
        }
        data = {
            "id": id,
            "value": data_value
        }
        datas.append(data)

    return datas
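

# Each item produced by get_qas() has this shape (illustrative values only):
#   {"id": 4, "value": {"texts": ["<question>", "<short answer>"]}}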


from tqdm import tqdm


def extract_embedding(datas, text_extractor):
    """
    Extract embeddings for each item in the provided data.

    Parameters:
    - datas (list of dict): A list of dictionaries containing text data.
    - text_extractor (TextExtractor): The extractor used to compute embeddings.

    Returns:
    - list of dict: The input data with added embeddings.
    """
    for data in tqdm(datas):
        texts = data["value"]["texts"]
        # Join question and answer into a single text before embedding.
        text = "。".join(texts)
        embeddings = text_extractor.extract(text)
        embeddings_list = embeddings.tolist()
        data["value"]["embedding"] = embeddings_list
    return datas


def save_parquet(datas, file_path):
    """
    Save the provided data to a Parquet file.

    Parameters:
    - datas (list of dict): A list of dictionaries containing text data and embeddings.
    - file_path (str): The path to the output Parquet file.
    """
    flattened_data = []
    for data in datas:
        id = data["id"]
        texts = data["value"]["texts"]
        text = "。".join(texts)
        embedding = data["value"]["embedding"]
        flattened_data.append({
            "id": id,
            "text": text,
            "embedding": embedding
        })

    df = pd.DataFrame(flattened_data)
    df.to_parquet(file_path, index=False)
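

# The resulting Parquet file has one row per QA item with columns id, text and
# embedding. Note that extract() returns a batch of shape (1, dim), so the stored
# embedding is a nested list like [[0.01, -0.03, ...]] (values illustrative); this
# is why get_id2embedding() below indexes embedding[0].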


def get_id2embedding(regen=False, parquet_file='data/qa_with_embedding.parquet'):
    """
    Get a dictionary mapping IDs to embeddings. Regenerate embeddings if specified.

    Parameters:
    - regen (bool): Whether to regenerate embeddings.
    - parquet_file (str): The path to the Parquet file.

    Returns:
    - dict: A dictionary mapping IDs to lists of float embeddings.
    """
    if regen or not os.path.exists(parquet_file):
        print("Regenerating embeddings...")

        model_name = 'BAAI/bge-small-zh-v1.5'
        text_extractor = TextExtractor(model_name)

        datas = get_qas()
        print("Extracting embeddings for", len(datas), "data items")

        datas = extract_embedding(datas, text_extractor)
        save_parquet(datas, parquet_file)

    df = pd.read_parquet(parquet_file)

    id2embedding = {}
    for index, row in df.iterrows():
        id = row['id']
        embedding = row['embedding']
        # The stored embedding is nested (a batch of size 1), so unwrap it.
        id2embedding[id] = embedding[0]

    return id2embedding


from sklearn.metrics.pairwise import cosine_similarity
import heapq


def __get_id2top30map(id2embedding):
    """
    Get a dictionary mapping IDs to their top 30 nearest neighbors based on cosine similarity.

    Parameters:
    - id2embedding (dict): A dictionary mapping IDs to lists of float embeddings.

    Returns:
    - dict: A dictionary mapping each ID to a list of the top 30 nearest neighbor IDs.
    """
    ids = list(id2embedding.keys())
    embeddings = torch.tensor([id2embedding[id] for id in ids])

    # Pairwise cosine similarity matrix over all embeddings (N x N).
    cos_sim_matrix = cosine_similarity(embeddings)

    id2top30map = {}
    for i, id in enumerate(ids):
        sim_scores = cos_sim_matrix[i]

        # Take the 31 largest scores so the item itself can be removed below.
        top_indices = heapq.nlargest(31, range(len(sim_scores)), key=lambda x: sim_scores[x])
        top_indices.remove(i)

        top_30_ids = [ids[idx] for idx in top_indices[:30]]
        id2top30map[id] = top_30_ids

    return id2top30map
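

# A purely illustrative alternative for the per-row top-k selection above, using a
# full sort instead of heapq (same result, assuming the same cos_sim_matrix):
#
#   order = cos_sim_matrix[i].argsort()[::-1]              # indices by descending similarity
#   top_30_ids = [ids[j] for j in order if j != i][:30]    # drop the item itself, keep 30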


import pickle


def get_id2top30map(id2embedding=None):
    """Return the ID -> top-30-neighbor map, caching it as a pickle when built from scratch."""
    default_save_pkl = "data/id2top30map.pkl"
    if id2embedding is None:
        if os.path.exists(default_save_pkl):
            with open(default_save_pkl, 'rb') as f:
                id2top30map = pickle.load(f)
        else:
            print("No embedding found, generating a new one...")
            id2embedding = get_id2embedding(regen=False)
            id2top30map = __get_id2top30map(id2embedding)
            with open(default_save_pkl, 'wb') as f:
                pickle.dump(id2top30map, f)
    else:
        id2top30map = __get_id2top30map(id2embedding)

    return id2top30map
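

# Note: the pickle cache above is not invalidated automatically. If the embeddings
# are regenerated, delete data/id2top30map.pkl (or call get_id2top30map with an
# explicit id2embedding) to force the neighbor map to be rebuilt.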


if __name__ == '__main__':
    # Manual switches: flip the booleans below to enable or skip each stage.
    if False:
        model_name = 'BAAI/bge-small-zh-v1.5'
        sentences = ["样例数据-1", "样例数据-2"]

        text_extractor = TextExtractor(model_name)
        embeddings = text_extractor.extract(sentences)
        print("Sentence embeddings:", embeddings)

        datas = get_qas()
        print("Extracting embeddings for", len(datas), "data items")

        datas = extract_embedding(datas, text_extractor)

        default_parquet_save_name = "data/qa_with_embedding.parquet"
        save_parquet(datas, default_parquet_save_name)

    if True:
        id2embedding = get_id2embedding(regen=False)
        print(len(id2embedding[4]))
        id2top30map = get_id2top30map(None)
        print("ID to Top 30 Neighbors dictionary:", id2top30map[4])

    if True:
        # Breadth-first expansion over the top-30 neighbor graph, starting from
        # start_id and taking at most expand_num new neighbors per visited node.
        start_id = 332
        visited_ids = [start_id]
        current_queue = [start_id]

        expand_num = 5

        for iteration in range(10):
            current_node = current_queue.pop(0)
            top30 = id2top30map[current_node]
            current_expand = []
            for id in top30:
                if id not in visited_ids:
                    visited_ids.append(id)
                    current_queue.append(id)
                    current_expand.append(id)
                if len(current_expand) >= expand_num:
                    break
            display_text = f"{current_node} | ->" + ",".join([str(i) for i in current_expand])
            print(display_text)

        # Collect the images associated with the visited QA items and copy them
        # into a single folder for inspection.
        from get_qa_and_image import get_qa_and_image
        image_datas = get_qa_and_image()

        id2index = {}
        for i, data in enumerate(image_datas):
            id2index[data['id']] = i

        indexes = [id2index[i] for i in visited_ids if i in id2index]
        image_names = [image_datas[index]['value']['image'] for index in indexes]

        target_copy_folder = "data/asso_collection"

        import shutil
        # Make sure the target folder exists; otherwise shutil.copy would write a
        # single file literally named "asso_collection" under data/.
        os.makedirs(target_copy_folder, exist_ok=True)
        for image_name in image_names:
            shutil.copy(image_name, target_copy_folder)