import nltk
from nltk.collocations import BigramAssocMeasures, BigramCollocationFinder
from nltk.corpus import stopwords
import spacy
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from datetime import datetime, timedelta
import numpy as np
import heapq
from concurrent.futures import ThreadPoolExecutor
from annoy import AnnoyIndex
from transformers import pipeline
from rank_bm25 import BM25Okapi
from functools import partial
import ssl
# Work around SSL certificate errors that some Python installs hit when
# downloading NLTK/spaCy resources over HTTPS.
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context
import sys
import subprocess
def download_spacy_model(model_name):
    print(f"Downloading spaCy model: {model_name}")
    subprocess.check_call([sys.executable, "-m", "spacy", "download", model_name])
    print(f"Model {model_name} downloaded successfully")
# Load the spaCy model, downloading it first if it is missing.
try:
    nlp = spacy.load('en_core_web_sm')
except OSError:
    download_spacy_model('en_core_web_sm')
    nlp = spacy.load('en_core_web_sm')
print("spaCy model loaded successfully")
nltk.download('stopwords')
nltk.download('punkt')
# Build the stop-word set once; get_keywords is called per message.
STOP_WORDS = set(stopwords.words('english'))
# Lightweight progress counter: prints every 1000th keyword-extraction call.
q = 0
def get_keywords(text, cache):
    global q
    if q % 1000 == 0:
        print(q)
    q += 1
    if text in cache:
        return cache[text]
    doc = nlp(text)
    # Keep nouns, proper nouns, and verbs as candidate keywords.
    keywords = [token.text.lower() for token in doc
                if token.pos_ in ('NOUN', 'PROPN', 'VERB')]
    keywords = [word for word in keywords if word not in STOP_WORDS]
    # Add the 10 highest-PMI bigrams as phrase keywords.
    bigram_measures = BigramAssocMeasures()
    finder = BigramCollocationFinder.from_words([token.text for token in doc])
    bigrams = finder.nbest(bigram_measures.pmi, 10)
    keywords.extend(' '.join(bigram) for bigram in bigrams)
    cache[text] = keywords
    return keywords
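# Hypothetical usage sketch: the cache dict is shared across calls, so a
# repeated text is parsed by spaCy only once.
#
#   cache = {}
#   get_keywords("Alice booked the flight to Paris", cache)  # parsed by spaCy
#   get_keywords("Alice booked the flight to Paris", cache)  # returned from cache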
def calculate_weight(message, sender_messages, cache):
    # Weight a message by how often its keywords recur in the sender's
    # messages from within five hours of it.
    message_time = datetime.strptime(message[1], '%Y-%m-%d %H:%M:%S')
    deltas = np.array([abs((datetime.strptime(m[1], '%Y-%m-%d %H:%M:%S') - message_time).total_seconds())
                       for m in sender_messages])
    recent_messages = np.asarray(sender_messages, dtype=object)[deltas <= 5 * 3600]
    recent_keywords = [get_keywords(m[2], cache) for m in recent_messages]
    keyword_counts = [sum(k.count(keyword) for k in recent_keywords)
                      for keyword in get_keywords(message[2], cache)]
    return sum(keyword_counts)
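# Note: calculate_weight is not referenced elsewhere in this file. A
# hypothetical call, assuming sender_messages holds (sender, time, text, tag)
# rows for one sender:
#
#   weight = calculate_weight(sender_messages[0], sender_messages, cache)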
class ChatDatabase:
    def __init__(self, filename):
        self.filename = filename
        self.messages = []
        self.messages_array = None
        self.sender_array = None
        self.load_messages()
        self.index = None
        self.tfidf = None
    def load_messages(self):
        # Touch the file so a missing database is created empty.
        with open(self.filename, 'a'):
            pass
        with open(self.filename, 'r') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) == 4:
                    sender, time, text, tag = parts
                elif len(parts) == 3:
                    sender, time, text = parts
                    tag = None
                else:
                    # Skip malformed lines instead of crashing.
                    continue
                self.messages.append((sender, time, text, tag))
        self.messages_array = np.array(self.messages, dtype=object)
        if len(self.messages_array) == 0:
            self.sender_array = np.array([], dtype=object)
        else:
            self.sender_array = self.messages_array[:, 0]
        print(f'Database loaded. Number of messages: {len(self.messages_array)}')
    def add_message(self, sender, time, text, tag=None):
        message = (sender, time, text, tag)
        self.messages.append(message)
        # np.append needs matching dimensions, so keep the array 2-D.
        row = np.array([message], dtype=object)
        if self.messages_array is None or len(self.messages_array) == 0:
            self.messages_array = row
        else:
            self.messages_array = np.append(self.messages_array, row, axis=0)
        self.sender_array = np.append(self.sender_array, sender)
        with open(self.filename, 'a') as f:
            f.write(f'{sender}\t{time}\t{text}\t{tag}\n')
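    # The commented usage example at the bottom of this file calls
    # build_index_separate, which is not defined here. A minimal sketch of a
    # plausible implementation, assuming the index covers every message: fit a
    # TF-IDF vectorizer over each message's extracted keywords and load the
    # vectors into an Annoy index for approximate nearest-neighbour lookup.
    # predict_response_separate below then filters neighbours by sender.
    def build_index_separate(self, cache, n_trees=10):
        keyword_docs = [' '.join(get_keywords(m[2], cache)) for m in self.messages_array]
        self.tfidf = TfidfVectorizer()
        vectors = self.tfidf.fit_transform(keyword_docs).toarray()
        self.index = AnnoyIndex(vectors.shape[1], 'angular')
        for i, vector in enumerate(vectors):
            self.index.add_item(i, vector)
        self.index.build(n_trees)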
    def predict_response_separate(self, query, sender, cache, n_candidates=10):
        if self.messages_array is None or len(self.messages_array) == 0:
            print("Error: messages_array is empty")
            return None
        if self.index is None or self.tfidf is None:
            print("Error: index not built; call build_index_separate first")
            return None
        if not np.any(self.sender_array == sender):
            print(f"No messages found for sender: {sender}")
            return None
        query_keywords = ' '.join(get_keywords(query, cache))
        query_vector = self.tfidf.transform([query_keywords]).toarray()[0]
        # Nearest neighbours over all messages; keep the closest one from this sender.
        candidates = self.index.get_nns_by_vector(query_vector, n_candidates)
        relevant_index = next((i for i in candidates if self.sender_array[i] == sender), None)
        if relevant_index is None:
            return None
        # Predict the response as the next message from a different sender.
        for i in range(relevant_index + 1, len(self.messages_array)):
            if self.sender_array[i] != sender:
                return tuple(self.messages_array[i])
        return None
    def get_relevant_messages(self, sender, query, N, cache, query_tag=None, n_threads=30, tag_boost=1.5):
        if self.messages_array is None or len(self.messages_array) == 0:
            print("Error: messages_array is empty")
            return []
        query_keywords = query.lower().split()
        # Filter to this sender's messages that contain at least one query keyword.
        sender_messages = self.messages_array[
            (self.sender_array == sender) &
            np.array([any(keyword in message.lower() for keyword in query_keywords)
                      for message in self.messages_array[:, 2]])
        ]
        if len(sender_messages) == 0:
            print(f"No messages found for sender: {sender} with the given keywords")
            return []
        print(f"Number of matching messages from sender {sender}: {len(sender_messages)}")
        def process_batch(batch, query_keywords, current_time, query_tag):
            # Blend BM25 relevance (0.6) with recency (0.2) and tag match (0.2).
            batch_keywords = [get_keywords(message[2], cache) for message in batch]
            bm25 = BM25Okapi(batch_keywords)
            bm25_scores = bm25.get_scores(query_keywords)
            # Recency: a message from today scores 1; older ones decay as 1/(1+days).
            ages_in_days = (current_time - np.array(
                [datetime.strptime(m[1], '%Y-%m-%d %H:%M:%S') for m in batch]
            )).astype('timedelta64[D]').astype(int)
            time_scores = 1 / (1 + ages_in_days)
            tag_scores = np.where(np.array([m[3] for m in batch]) == query_tag, tag_boost, 1)
            combined_scores = 0.6 * np.array(bm25_scores) + 0.2 * time_scores + 0.2 * tag_scores
            return combined_scores, batch
        current_time = datetime.now()
        # Split the candidates into batches and score them in parallel.
        batch_size = max(1, len(sender_messages) // n_threads)
        batches = [sender_messages[i:i+batch_size] for i in range(0, len(sender_messages), batch_size)]
        with ThreadPoolExecutor(max_workers=n_threads) as executor:
            process_func = partial(process_batch, query_keywords=query_keywords, current_time=current_time, query_tag=query_tag)
            results = list(executor.map(process_func, batches))
        all_scores = np.concatenate([r[0] for r in results])
        all_messages = np.concatenate([r[1] for r in results])
        # Return the N highest-scoring messages, best first.
        top_indices = np.argsort(all_scores)[-N:][::-1]
        relevant_messages = all_messages[top_indices]
        return relevant_messages.tolist()
    def generate_response(self, query, sender, cache, query_tag=None):
        relevant_messages = self.get_relevant_messages(sender, query, 5, cache, query_tag)
        context = ' '.join([message[2] for message in relevant_messages])
        # Build the generation pipeline once and reuse it; reloading the model
        # on every call would be prohibitively slow.
        if not hasattr(self, 'generator'):
            self.generator = pipeline('text-generation', model='EleutherAI/gpt-neo-2.7B')
        response = self.generator(f'{context} {query}', max_length=100, do_sample=True)[0]['generated_text']
        # Keep only the text generated after the query itself.
        response = response.split(query)[-1].strip()
        return response
# Usage example (commented out)
'''
db = ChatDatabase('memory.txt')

# Example 1: Get relevant messages
query = 'fisical'
sender = 'Arcana'
N = 10
cache = {}
query_tag = None
relevant_messages = db.get_relevant_messages(sender, query, N, cache, query_tag)
print("Relevant messages:")
for message in relevant_messages:
    print(f"Sender: {message[0]}, Time: {message[1]}, Tag: {message[3]}")
    print(f"Message: {message[2][:100]}...")
    print()

# Example 2: Predict response (using the original method)
query = "what was that?"
sender = 'David'
db.build_index_separate(cache)
predicted_response = db.predict_response_separate(query, sender, cache)
print("\nPredicted response:")
if predicted_response is not None:
    print(f"Sender: {predicted_response[0]}, Time: {predicted_response[1]}, Tag: {predicted_response[3]}")
    print(f"Message: {predicted_response[2][:100]}...")
else:
    print('No predicted response found')

# Example 3: Generate response
query = "Let's plan a trip"
sender = 'Alice'
query_tag = 'travel'
generated_response = db.generate_response(query, sender, cache, query_tag)
print("\nGenerated response:")
print(generated_response)
'''