|
import logging
import os
import xml.etree.ElementTree as ET
from collections import defaultdict

import numpy as np
import requests
import torch
import torch.nn as nn
from accelerate import Accelerator
from sklearn.metrics.pairwise import cosine_similarity
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
|
|
|
|
|
class Config:
    # Hyperparameter constants; defined here but not referenced elsewhere in this script.
    E, H, N, C, B = 512, 32, 1024, 256, 128
    M, S, V = 20000, 2048, 1e5
    W, L, D = 4000, 2e-4, .15
|
|
|
|
|
class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.labels[index]
|
|
|
|
|
class MyModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Project the raw features to the hidden width before the LSTM.
        self.hidden = nn.Linear(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x: (batch, seq_len, input_size)
        x = torch.relu(self.hidden(x))
        h0 = torch.zeros(1, x.size(0), self.hidden_size, device=x.device)
        c0 = torch.zeros(1, x.size(0), self.hidden_size, device=x.device)
        out, _ = self.lstm(x, (h0, c0))
        # Classify from the final time step's hidden state.
        return self.fc(out[:, -1, :])
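

# Illustrative sketch (not called anywhere): MyModel expects a 3-D batch of
# sequences shaped (batch, seq_len, input_size); the sizes here are arbitrary.
def _example_my_model():
    model = MyModel(input_size=16, hidden_size=32, output_size=5)
    x = torch.randn(8, 20, 16)  # 8 sequences, 20 time steps, 16 features each
    return model(x)  # logits of shape (8, 5)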
|
|
|
|
|
class MemoryNetwork:
    def __init__(self, memory_size, embedding_size):
        self.memory_size = memory_size
        self.embedding_size = embedding_size
        self.memory = np.zeros((memory_size, embedding_size))
        self.usage = np.zeros(memory_size)

    def store(self, data):
        # Overwrite the least-used slot.
        index = np.argmin(self.usage)
        self.memory[index] = data
        self.usage[index] = 1.0

    def retrieve(self, query):
        # Return the stored vector with the highest dot-product similarity.
        similarities = np.dot(self.memory, query)
        index = np.argmax(similarities)
        self.usage[index] += 1.0
        return self.memory[index]

    def update_usage(self):
        # Decay usage so rarely retrieved slots become candidates for overwriting.
        self.usage *= 0.99
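

# Illustrative sketch (not called anywhere): store a few vectors, then retrieve
# the closest match for a query. The random vectors are placeholders.
def _example_memory_network():
    memory = MemoryNetwork(memory_size=100, embedding_size=384)
    for vector in np.random.randn(5, 384):
        memory.store(vector)
    nearest = memory.retrieve(np.random.randn(384))
    memory.update_usage()  # decay usage between interactions
    return nearest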
|
|
|
|
|
class DM(nn.Module):
    # Dynamic model: maps a section name to a list of layer specs, each built
    # into Linear -> (BatchNorm) -> activation -> (Dropout).
    def __init__(self, sections):
        super().__init__()
        self.sections = nn.ModuleDict()
        for name, layer_params in sections.items():
            self.sections[name] = nn.ModuleList([self.build_layer(p) for p in layer_params])

    def build_layer(self, params):
        layers = [nn.Linear(params['input_size'], params['output_size'])]
        if params.get('batch_norm', True):
            layers.append(nn.BatchNorm1d(params['output_size']))
        activation = params.get('activation', 'relu')
        if activation == 'relu':
            layers.append(nn.ReLU(inplace=True))
        elif activation == 'tanh':
            layers.append(nn.Tanh())
        elif activation == 'sigmoid':
            layers.append(nn.Sigmoid())
        elif activation == 'leaky_relu':
            layers.append(nn.LeakyReLU(negative_slope=0.01, inplace=True))
        elif activation == 'elu':
            layers.append(nn.ELU(alpha=1.0, inplace=True))
        if dropout := params.get('dropout', 0.0):
            layers.append(nn.Dropout(p=dropout))
        return nn.Sequential(*layers)

    def forward(self, x, section=None):
        if section is not None:
            for layer in self.sections[section]:
                x = layer(x)
        else:
            # No section given: run every section in insertion order.
            for layers in self.sections.values():
                for layer in layers:
                    x = layer(x)
        return x
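

# Illustrative sketch (not called anywhere): the spec format DM consumes,
# normally produced by create_model_from_folder below. Section name and sizes
# are arbitrary.
def _example_dm():
    spec = {
        'intro': [
            {'input_size': 10, 'output_size': 64, 'activation': 'relu'},
            {'input_size': 64, 'output_size': 5, 'activation': 'sigmoid', 'dropout': 0.1},
        ],
    }
    model = DM(spec)
    return model(torch.randn(4, 10), section='intro')  # shape (4, 5)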
|
|
|
|
|
def parse_xml(file_path):
    # Collect one spec dict per <layer> element anywhere in the document.
    tree = ET.parse(file_path)
    root = tree.getroot()
    layers = []
    for layer in root.findall('.//layer'):
        layers.append({
            'input_size': int(layer.get('input_size', 128)),
            'output_size': int(layer.get('output_size', 256)),
            'activation': layer.get('activation', 'relu').lower(),
        })
    return layers
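

# Illustrative sketch: the XML shape parse_xml expects, inferred from the
# attribute names read above rather than from any documented schema.
_EXAMPLE_LAYER_XML = """\
<section>
    <layer input_size="128" output_size="256" activation="relu"/>
    <layer input_size="256" output_size="64" activation="tanh"/>
</section>
"""

def _example_parse_xml():
    import io
    # ET.parse also accepts file-like objects, so no file on disk is needed.
    return parse_xml(io.StringIO(_EXAMPLE_LAYER_XML))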
|
|
|
|
|
def create_model_from_folder(folder_path):
    # Walk the folder tree; every XML file contributes layers to the section
    # named after its containing folder.
    sections = defaultdict(list)
    for root_dir, _, files in os.walk(folder_path):
        for file in files:
            if file.endswith('.xml'):
                layers = parse_xml(os.path.join(root_dir, file))
                section_name = os.path.basename(root_dir).replace('.', '_')
                sections[section_name].extend(layers)
    return DM(dict(sections))
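

# Illustrative sketch of the folder layout create_model_from_folder walks
# (folder names are hypothetical): each subfolder becomes one section, named
# after the folder with '.' replaced by '_'.
#
#   data/
#       encoder/
#           layers.xml
#       classifier/
#           layers.xml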
|
|
|
|
|
def create_embeddings_and_sentences(folder_path, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.eval()
    embeddings, sentences = [], []
    for root_dir, _, files in os.walk(folder_path):
        for file in files:
            if not file.endswith('.xml'):
                continue
            tree = ET.parse(os.path.join(root_dir, file))
            for element in tree.getroot().iter():
                if element.text and element.text.strip():
                    text = element.text.strip()
                    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
                    with torch.no_grad():
                        # Mean-pool token embeddings into a single sentence vector.
                        emb = model(**inputs).last_hidden_state.mean(dim=1).numpy()
                    embeddings.append(emb)
                    sentences.append(text)
    return np.vstack(embeddings), sentences
|
|
|
|
|
def query_vector_similarity(query, embeddings, sentences, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.eval()
    inputs = tokenizer(query, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        query_embedding = model(**inputs).last_hidden_state.mean(dim=1).numpy()
    similarities = cosine_similarity(query_embedding, embeddings)
    # Indices of the five most similar sentences, best match first.
    top_k = similarities[0].argsort()[-5:][::-1]
    return [sentences[i] for i in top_k]
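

# Illustrative sketch (assumes a 'data' folder of XML files exists): embed the
# corpus once, then rank its sentences against a query.
def _example_similarity_query():
    embeddings, sentences = create_embeddings_and_sentences('data')
    return query_vector_similarity("example query text", embeddings, sentences)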
|
|
|
|
|
def fetch_nzlii_data(query):
    # Search NZLII for matching cases and normalise the results into plain dicts.
    base_url = "https://nzlii.org/cgi-bin/sinosrch.cgi"
    params = {"method": "auto", "query": query, "meta": "/nz", "results": "50", "format": "json"}
    try:
        response = requests.get(base_url, params=params, headers={"Accept": "application/json"}, timeout=10)
        response.raise_for_status()
        results = response.json().get("results", [])
        return [{
            "title": r.get("title", ""),
            "citation": r.get("citation", ""),
            "date": r.get("date", ""),
            "court": r.get("court", ""),
            "summary": r.get("summary", ""),
            "url": r.get("url", ""),
        } for r in results]
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError also covers JSON decoding failures on older requests versions.
        logging.error(f"Failed to fetch data from NZLII: {e}")
        return []
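

# Illustrative sketch (not called with real data here): log title and citation
# for each result; the field names match the dicts built above.
def _example_nzlii_fetch():
    for case in fetch_nzlii_data("negligence"):
        logging.info(f"{case['title']} ({case['citation']})")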
|
|
|
|
|
def main():
    folder_path = 'data'
    model = create_model_from_folder(folder_path)
    logging.info(f"Created dynamic PyTorch model with sections: {list(model.sections.keys())}")
    embeddings, sentences = create_embeddings_and_sentences(folder_path)

    accelerator = Accelerator()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    num_epochs = 10

    # Synthetic data for demonstration; the feature width (10) must match the
    # input_size of the model's first layer for training to run.
    dataset = MyDataset(torch.randn(1000, 10), torch.randint(0, 5, (1000,)))
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    for epoch in range(num_epochs):
        model.train()
        for batch_data, batch_labels in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            optimizer.zero_grad()
            outputs = model(batch_data)
            loss = criterion(outputs, batch_labels)
            accelerator.backward(loss)
            optimizer.step()
        logging.info(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

    query = "example query text"
    results = query_vector_similarity(query, embeddings, sentences)
    logging.info(f"Query results: {results}")
    nzlii_data = fetch_nzlii_data(query)
    logging.info(f"NZLII search results: {nzlii_data}")


if __name__ == "__main__":
    main()