# python / m5.py
# Author: Princess3
# Last commit: "Update m5.py" (8417cb6, verified)
# Size: 9.09 kB
import os, xml.etree.ElementTree as ET, torch, torch.nn as nn, torch.nn.functional as F, numpy as np, logging, requests
from typing import List, Dict, Any, Optional
from collections import defaultdict
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import termcolor
# Directory where Hugging Face model/tokenizer downloads should be cached.
cache_dir = '/app/cache'
# Create the directory (and any missing parents) if needed. exist_ok=True avoids
# the check-then-create race of a separate os.path.exists() test.
os.makedirs(cache_dir, exist_ok=True)
# Point transformers at the cache directory.
# NOTE(review): `transformers` is imported above this line; depending on the
# installed version it may already have resolved its cache location at import
# time, in which case this assignment has no effect. Confirm, or set the
# variable before importing transformers / pass cache_dir= to from_pretrained.
os.environ['TRANSFORMERS_CACHE'] = cache_dir
# Echo the configured value so it is visible in the process output.
print(f"TRANSFORMERS_CACHE is set to: {os.environ['TRANSFORMERS_CACHE']}")
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class DM(nn.Module):
    """Dynamic model assembled from per-section lists of layer specifications.

    Args:
        s: mapping of section name -> list of layer-parameter dicts. Each dict
            is turned into one ``nn.Sequential`` by :meth:`cl`. An empty/falsy
            ``s`` falls back to a single ``'default'`` section
            (128 -> 256, ReLU, batch-norm, dropout 0.1).

    Attributes:
        s: ``nn.ModuleDict`` of section name -> ``nn.ModuleList`` of layers.

    Unless disabled in a spec (``memory_augmentation``, ``hybrid_attention``,
    ``dynamic_flash_attention``), every layer is followed by MAL, HAL and DFAL
    modules of width ``output_size``.
    """

    # Factories for the supported activation names. None and the string
    # 'none' both mean "no activation layer".
    _ACTIVATIONS = {
        'relu': lambda: nn.ReLU(inplace=True),
        'tanh': lambda: nn.Tanh(),
        'sigmoid': lambda: nn.Sigmoid(),
        'leaky_relu': lambda: nn.LeakyReLU(negative_slope=0.01, inplace=True),
        'elu': lambda: nn.ELU(alpha=1.0, inplace=True),
    }

    def __init__(self, s: Dict[str, List[Dict[str, Any]]]):
        super().__init__()
        self.s = nn.ModuleDict()
        if not s:
            s = {'default': [{'input_size': 128, 'output_size': 256, 'activation': 'relu', 'batch_norm': True, 'dropout': 0.1}]}
        for section_name, layer_specs in s.items():
            self.s[section_name] = nn.ModuleList()
            for lp in layer_specs:
                logging.info(f"Creating layer in section '{section_name}' with params: {lp}")
                self.s[section_name].append(self.cl(lp))

    def cl(self, lp: Dict[str, Any]) -> nn.Module:
        """Build one ``nn.Sequential`` from a layer-parameter dict.

        Order: Linear -> optional BatchNorm1d -> optional activation ->
        optional Dropout -> nested ``hidden_layers`` (recursively) ->
        optional MAL / HAL / DFAL.

        Raises:
            ValueError: the activation name is not supported.
            KeyError: ``input_size`` or ``output_size`` is missing.
        """
        out_size = lp['output_size']
        layers: List[nn.Module] = [nn.Linear(lp['input_size'], out_size)]
        if lp.get('batch_norm', True):
            layers.append(nn.BatchNorm1d(out_size))
        act = lp.get('activation', 'relu')
        # 'none' is accepted alongside None: px() lets the XML value 'none'
        # through its validation, and it previously crashed here.
        if act is not None and act != 'none':
            factory = self._ACTIVATIONS.get(act) if isinstance(act, str) else None
            if factory is None:
                raise ValueError(f"Unsupported activation function: {act}")
            layers.append(factory())
        if dr := lp.get('dropout', 0.0):
            layers.append(nn.Dropout(p=dr))
        for hlp in lp.get('hidden_layers') or []:
            layers.append(self.cl(hlp))
        # These extras are on by default. HAL/DFAL use 8 attention heads, so
        # output_size must be divisible by 8 while they are enabled.
        if lp.get('memory_augmentation', True):
            layers.append(MAL(out_size))
        if lp.get('hybrid_attention', True):
            layers.append(HAL(out_size))
        if lp.get('dynamic_flash_attention', True):
            layers.append(DFAL(out_size))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, sn: Optional[str] = None) -> torch.Tensor:
        """Run ``x`` through section ``sn``, or through every section in order.

        Raises:
            KeyError: ``sn`` is given but is not a known section.
        """
        if sn is not None:
            if sn not in self.s:
                raise KeyError(f"Section '{sn}' not found in model")
            sections = [self.s[sn]]
        else:
            sections = list(self.s.values())
        for section in sections:
            for layer in section:
                x = layer(x)
        return x
class MAL(nn.Module):
    """Add a learnable vector of length ``s`` to the input.

    The vector is initialised from a standard normal distribution and is
    broadcast against the trailing dimension of the input.
    """

    def __init__(self, s: int):
        super().__init__()
        initial = torch.randn(s)
        self.m = nn.Parameter(initial)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shifted = x + self.m
        return shifted
class HAL(nn.Module):
    """Per-sample 8-head self-attention over a single feature vector.

    Args:
        s: feature size (``embed_dim``); nn.MultiheadAttention requires it to
            be divisible by the 8 heads.
    """

    def __init__(self, s: int):
        super().__init__()
        self.a = nn.MultiheadAttention(s, num_heads=8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumes x is 2-D (batch, s). nn.MultiheadAttention defaults to
        # batch_first=False, i.e. (seq_len, batch, embed). Insert the sequence
        # axis at dim 0 so each row is its own length-1 sequence. Inserting it
        # at dim 1 would treat the whole batch as one sequence and let samples
        # attend to each other, making outputs depend on batch composition.
        seq = x.unsqueeze(0)
        ao, _ = self.a(seq, seq, seq)
        return ao.squeeze(0)
class DFAL(nn.Module):
    """Per-sample 8-head self-attention over a single feature vector.

    Args:
        s: feature size (``embed_dim``); nn.MultiheadAttention requires it to
            be divisible by the 8 heads.
    """

    def __init__(self, s: int):
        super().__init__()
        self.a = nn.MultiheadAttention(s, num_heads=8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumes x is 2-D (batch, s). nn.MultiheadAttention defaults to
        # batch_first=False, i.e. (seq_len, batch, embed). Insert the sequence
        # axis at dim 0 so each row is its own length-1 sequence. Inserting it
        # at dim 1 would treat the whole batch as one sequence and let samples
        # attend to each other, making outputs depend on batch composition.
        seq = x.unsqueeze(0)
        ao, _ = self.a(seq, seq, seq)
        return ao.squeeze(0)
def px(file_path: str) -> List[Dict[str, Any]]:
    """Parse every ``<layer>`` element of an XML file into a layer-spec dict.

    Each ``<layer>`` found anywhere below the root contributes
    ``{'input_size', 'output_size', 'activation'}``. The defaults are 128, 256
    and ``'relu'``, and the activation name is lower-cased. If the file has no
    ``<layer>`` elements, a single default spec is returned.

    Raises:
        ValueError: unsupported activation name, or a size that is not a
            positive integer.
        xml.etree.ElementTree.ParseError: the file is not well-formed XML.
        OSError: the file cannot be read.
    """
    # leaky_relu and elu are also built by DM.cl, so accept them here too.
    supported = ('relu', 'tanh', 'sigmoid', 'leaky_relu', 'elu', 'none')
    # NOTE(review): xml.etree is not hardened against malicious XML (e.g.
    # entity-expansion attacks); consider defusedxml if these files can come
    # from untrusted sources.
    root = ET.parse(file_path).getroot()
    specs = []
    for layer in root.findall('.//layer'):
        spec = {
            'input_size': int(layer.get('input_size', 128)),
            'output_size': int(layer.get('output_size', 256)),
            'activation': layer.get('activation', 'relu').lower(),
        }
        if spec['activation'] not in supported:
            raise ValueError(f"Unsupported activation function: {spec['activation']}")
        if spec['input_size'] <= 0 or spec['output_size'] <= 0:
            raise ValueError("Layer dimensions must be positive integers")
        specs.append(spec)
    if not specs:
        specs.append({'input_size': 128, 'output_size': 256, 'activation': 'relu'})
    return specs
def cmf(folder_path: str) -> DM:
    """Build a DM from every ``*.xml`` layer file under ``folder_path``.

    Layer specs are grouped into sections named after the directory that
    contains each file, with '.' replaced by '_' because nn.ModuleDict keys may
    not contain dots. Files that fail to parse are logged and skipped. DM's
    default configuration is used when the folder does not exist or contains
    no XML files.
    """
    if not os.path.exists(folder_path):
        logging.warning(f"Folder {folder_path} does not exist. Creating model with default configuration.")
        return DM({})
    sections = defaultdict(list)
    # Must start False; it was initialised to True, so the "no XML files"
    # branch below could never run.
    found_xml = False
    for root, dirs, files in os.walk(folder_path):
        # Sorting (dirs in place) makes traversal, and thus layer order, deterministic.
        dirs.sort()
        for name in sorted(files):
            if not name.endswith('.xml'):
                continue
            found_xml = True
            path = os.path.join(root, name)
            try:
                layers = px(path)
            except Exception as e:
                logging.error(f"Error processing {path}: {str(e)}")
                continue
            # normpath drops a trailing separator (e.g. 'data/'); otherwise
            # basename would be '' and ModuleDict rejects empty keys.
            section = os.path.basename(os.path.normpath(root)).replace('.', '_')
            sections[section].extend(layers)
    if not found_xml:
        logging.warning("No XML files found. Creating model with default configuration.")
        return DM({})
    return DM(dict(sections))
def ceas(folder_path: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Embed the text of every element in every ``*.xml`` file under ``folder_path``.

    Each element with non-blank text is tokenized on its own, and the model's
    ``last_hidden_state`` is averaged over dim 1 (presumably the token axis).

    Returns:
        ``(embeddings, ds)``. ``embeddings`` is a NumPy array with one row per
        entry of ``ds``, the list of stripped element texts. When no text is
        found, an empty ``(0, hidden_size)`` array is returned instead of
        raising from ``np.vstack([])``.
    """
    t = AutoTokenizer.from_pretrained(model_name)
    m = AutoModel.from_pretrained(model_name)
    embeddings = []
    ds = []
    for root, _dirs, files in os.walk(folder_path):
        for name in files:
            if not name.endswith('.xml'):
                continue
            fp = os.path.join(root, name)
            try:
                for elem in ET.parse(fp).getroot().iter():
                    text = (elem.text or '').strip()
                    # Skip whitespace-only text, such as the indentation between
                    # child elements; it would otherwise be embedded as "".
                    if not text:
                        continue
                    i = t(text, return_tensors="pt", truncation=True, padding=True)
                    with torch.no_grad():
                        emb = m(**i).last_hidden_state.mean(dim=1).numpy()
                    embeddings.append(emb)
                    ds.append(text)
            except Exception as e:
                logging.error(f"Error processing {fp}: {str(e)}")
    if not embeddings:
        hidden = getattr(m.config, 'hidden_size', 0)
        return np.empty((0, hidden), dtype=np.float32), ds
    return np.vstack(embeddings), ds
def qvs(query: str, embeddings, ds, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", top_k: int = 5):
    """Return up to ``top_k`` entries of ``ds`` most cosine-similar to ``query``.

    ``embeddings`` must be row-aligned with ``ds`` (as returned by ``ceas``).
    The query is embedded the same way ``ceas`` embeds corpus text. Results
    are ordered best-first. Returns ``[]`` for an empty corpus or a
    non-positive ``top_k``, without loading the model.
    """
    # cosine_similarity rejects an empty corpus, and argsort()[-0:] would
    # return *every* index, so both cases are handled up front.
    if top_k <= 0 or len(ds) == 0 or len(embeddings) == 0:
        return []
    t = AutoTokenizer.from_pretrained(model_name)
    m = AutoModel.from_pretrained(model_name)
    inputs = t(query, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        qe = m(**inputs).last_hidden_state.mean(dim=1).numpy()
    similarities = cosine_similarity(qe, embeddings)[0]
    # argsort is ascending: keep the last top_k indices, reversed for best-first.
    best = similarities.argsort()[-top_k:][::-1]
    return [ds[idx] for idx in best]
def fetch_courtlistener_data(query: str) -> List[Dict[str, Any]]:
    """Search NZLII for ``query`` and return simplified result records.

    Despite the function name, the request goes to the NZLII ``sinosrch.cgi``
    endpoint. Each returned dict has the keys ``title``, ``citation``,
    ``date``, ``court``, ``summary`` and ``url``; missing fields become "".

    Network/HTTP failures, undecodable JSON and unexpected payload shapes are
    logged and produce an empty list rather than raising.
    """
    base_url = "https://nzlii.org/cgi-bin/sinosrch.cgi"
    params = {
        "method": "auto",
        "query": query,
        "meta": "/nz",
        "mask_path": "",
        "results": "50",
        "format": "json"
    }
    fields = ("title", "citation", "date", "court", "summary", "url")
    try:
        response = requests.get(base_url, params=params, headers={"Accept": "application/json"}, timeout=10)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        logging.error(f"Failed to fetch data from NZLII API: {str(e)}")
        return []
    # Decode in a separate try: in recent requests versions JSONDecodeError
    # also subclasses RequestException, so a combined try would report a parse
    # failure as a fetch failure.
    try:
        payload = response.json()
    except ValueError as e:
        logging.error(f"Failed to parse NZLII API response: {str(e)}")
        return []
    results = payload.get("results", []) if isinstance(payload, dict) else None
    if not isinstance(results, list):
        logging.error("Unexpected NZLII API response format")
        return []
    return [
        {key: result.get(key, "") for key in fields}
        for result in results
        if isinstance(result, dict)
    ]
def main():
    """Demo pipeline for the model built from the ``data`` folder.

    Steps:
        1. Build a DM from ``data`` and run a one-sample smoke-test forward pass.
        2. Train it for a few epochs on random placeholder data.
        3. Run a semantic query over the XML text, then an NZLII search.
    """
    fp = 'data'
    m = cmf(fp)
    logging.info(f"Created dynamic PyTorch model with sections: {list(m.s.keys())}")
    first_section = next(iter(m.s.keys()))
    ife = m.s[first_section][0][0].in_features
    # BatchNorm1d cannot compute batch statistics from a single sample in
    # training mode, so run the smoke-test forward pass in eval mode.
    m.eval()
    with torch.no_grad():
        sample_out = m(torch.randn(1, ife))
    logging.info(f"Sample output shape: {sample_out.shape}")
    embeddings, ds = ceas(fp)
    accelerator = Accelerator()
    optimizer = torch.optim.Adam(m.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    ne = 10
    # Random placeholder data: 100 samples with binary class labels.
    dataset = torch.utils.data.TensorDataset(torch.randn(100, ife), torch.randint(0, 2, (100,)))
    td = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
    m, optimizer, td = accelerator.prepare(m, optimizer, td)
    for e in range(ne):
        m.train()
        tl = 0.0
        for inputs, labels in td:
            optimizer.zero_grad()
            # Keep the model output separate from `optimizer`; reusing one
            # name made optimizer.step() fail on a tensor.
            outputs = m(inputs)
            loss = criterion(outputs, labels)
            accelerator.backward(loss)
            optimizer.step()
            tl += loss.item()
        al = tl / len(td)
        logging.info(f"Epoch {e+1}/{ne}, Average Loss: {al:.4f}")
    uq = "example query text"
    if len(ds) > 0:
        r = qvs(uq, embeddings, ds)
        logging.info(f"Query results: {r}")
    else:
        logging.warning("No XML text was embedded; skipping similarity query.")
    cl_data = fetch_courtlistener_data(uq)
    logging.info(f"CourtListener API results: {cl_data}")


if __name__ == "__main__":
    main()