import logging
import os
import xml.etree.ElementTree as ET
from collections import defaultdict
from typing import Any, Dict, List, Optional

import numpy as np
import requests
import torch
import torch.nn as nn
from accelerate import Accelerator
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer
# Set the cache directory path and create it if it doesn't exist.
cache_dir = '/app/cache'
os.makedirs(cache_dir, exist_ok=True)

# Point the transformers library at the cache directory.
os.environ['TRANSFORMERS_CACHE'] = cache_dir

# Verify the environment variable is set.
print(f"TRANSFORMERS_CACHE is set to: {os.environ['TRANSFORMERS_CACHE']}")

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class DM(nn.Module):
    """Dynamic model whose sections and layers are built from parsed XML configs."""

    def __init__(self, s: Dict[str, List[Dict[str, Any]]]):
        super().__init__()
        self.s = nn.ModuleDict()
        # Fall back to a single default section when no config is supplied.
        if not s:
            s = {'default': [{'input_size': 128, 'output_size': 256, 'activation': 'relu', 'batch_norm': True, 'dropout': 0.1}]}
        for sn, layers in s.items():
            self.s[sn] = nn.ModuleList()
            for lp in layers:
                logging.info(f"Creating layer in section '{sn}' with params: {lp}")
                self.s[sn].append(self.cl(lp))

    def cl(self, lp: Dict[str, Any]) -> nn.Module:
        """Build one nn.Sequential block from a layer-parameter dict."""
        l = [nn.Linear(lp['input_size'], lp['output_size'])]
        if lp.get('batch_norm', True):
            l.append(nn.BatchNorm1d(lp['output_size']))
        a = lp.get('activation', 'relu')
        if a == 'relu':
            l.append(nn.ReLU(inplace=True))
        elif a == 'tanh':
            l.append(nn.Tanh())
        elif a == 'sigmoid':
            l.append(nn.Sigmoid())
        elif a == 'leaky_relu':
            l.append(nn.LeakyReLU(negative_slope=0.01, inplace=True))
        elif a == 'elu':
            l.append(nn.ELU(alpha=1.0, inplace=True))
        elif a not in (None, 'none'):
            # px() emits 'none' for no activation, so it must not raise here.
            raise ValueError(f"Unsupported activation function: {a}")
        if dr := lp.get('dropout', 0.0):
            l.append(nn.Dropout(p=dr))
        # Optional nested layers and augmentation modules.
        for hlp in lp.get('hidden_layers', []):
            l.append(self.cl(hlp))
        if lp.get('memory_augmentation', True):
            l.append(MAL(lp['output_size']))
        if lp.get('hybrid_attention', True):
            l.append(HAL(lp['output_size']))
        if lp.get('dynamic_flash_attention', True):
            l.append(DFAL(lp['output_size']))
        return nn.Sequential(*l)

    def forward(self, x: torch.Tensor, sn: Optional[str] = None) -> torch.Tensor:
        # Run a single named section, or all sections in insertion order.
        if sn is not None:
            if sn not in self.s:
                raise KeyError(f"Section '{sn}' not found in model")
            for layer in self.s[sn]:
                x = layer(x)
        else:
            for layers in self.s.values():
                for layer in layers:
                    x = layer(x)
        return x
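
# A minimal construction sketch (the section names here are hypothetical; the
# dicts use only keys that cl() reads above):
#
#   sections = {
#       'encoder': [{'input_size': 128, 'output_size': 256, 'activation': 'relu'}],
#       'head': [{'input_size': 256, 'output_size': 64, 'activation': 'tanh'}],
#   }
#   model = DM(sections)
#   y = model(torch.randn(2, 128), sn='encoder')  # run only the 'encoder' section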
class MAL(nn.Module):
    """Memory-augmentation layer: adds a learned bias vector to the input."""

    def __init__(self, s: int):
        super().__init__()
        self.m = nn.Parameter(torch.randn(s))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.m


class HAL(nn.Module):
    """Hybrid-attention layer: self-attention over a length-1 sequence per sample."""

    def __init__(self, s: int):
        super().__init__()
        # batch_first=True so the (batch, seq, embed) layout produced in
        # forward() is attended per sample instead of mixing samples across
        # the batch dimension. s must be divisible by num_heads.
        self.a = nn.MultiheadAttention(s, num_heads=8, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.unsqueeze(1)  # (batch, 1, embed)
        ao, _ = self.a(x, x, x)
        return ao.squeeze(1)


class DFAL(nn.Module):
    """Dynamic flash-attention layer (structurally identical to HAL in this version)."""

    def __init__(self, s: int):
        super().__init__()
        self.a = nn.MultiheadAttention(s, num_heads=8, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.unsqueeze(1)
        ao, _ = self.a(x, x, x)
        return ao.squeeze(1)
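
# Quick shape check (a sketch): both attention wrappers preserve (batch, s),
# and s must be divisible by num_heads=8:
#
#   x = torch.randn(4, 64)
#   assert HAL(64)(x).shape == (4, 64)
#   assert DFAL(64)(x).shape == (4, 64)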
def px(file_path: str) -> List[Dict[str, Any]]:
    """Parse one XML config file into a list of layer-parameter dicts."""
    t = ET.parse(file_path)
    r = t.getroot()
    l = []
    for ly in r.findall('.//layer'):
        lp = {
            'input_size': int(ly.get('input_size', 128)),
            'output_size': int(ly.get('output_size', 256)),
            'activation': ly.get('activation', 'relu').lower(),
        }
        # Keep this whitelist in sync with the activations DM.cl() can build.
        if lp['activation'] not in ['relu', 'tanh', 'sigmoid', 'leaky_relu', 'elu', 'none']:
            raise ValueError(f"Unsupported activation function: {lp['activation']}")
        if lp['input_size'] <= 0 or lp['output_size'] <= 0:
            raise ValueError("Layer dimensions must be positive integers")
        l.append(lp)
    if not l:
        l.append({'input_size': 128, 'output_size': 256, 'activation': 'relu'})
    return l
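
# Example of an XML file px() accepts (a sketch; <layer> elements may appear
# at any depth because of the './/layer' search):
#
#   <model>
#     <layer input_size="128" output_size="256" activation="relu"/>
#     <layer input_size="256" output_size="64" activation="tanh"/>
#   </model>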
def cmf(folder_path: str) -> DM:
    """Walk folder_path, parse every XML file, and build a DM from the results."""
    s = defaultdict(list)
    if not os.path.exists(folder_path):
        logging.warning(f"Folder {folder_path} does not exist. Creating model with default configuration.")
        return DM({})
    xf = False  # tracks whether any XML file was found
    for r, d, f in os.walk(folder_path):
        for file in f:
            if file.endswith('.xml'):
                xf = True
                fp = os.path.join(r, file)
                try:
                    l = px(fp)
                    # The section name comes from the containing directory.
                    sn = os.path.basename(r).replace('.', '_')
                    s[sn].extend(l)
                except Exception as e:
                    logging.error(f"Error processing {fp}: {str(e)}")
    if not xf:
        logging.warning("No XML files found. Creating model with default configuration.")
        return DM({})
    return DM(dict(s))
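
# Folder layout sketch (hypothetical names): each directory that contains XML
# files becomes one model section, named after that directory:
#
#   data/
#     encoder/a.xml   -> section 'encoder'
#     head/b.xml      -> section 'head'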
def ceas(folder_path: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Embed the text content of every XML element under folder_path."""
    t = AutoTokenizer.from_pretrained(model_name)
    m = AutoModel.from_pretrained(model_name)
    embeddings = []
    ds = []
    for r, d, f in os.walk(folder_path):
        for file in f:
            if file.endswith('.xml'):
                fp = os.path.join(r, file)
                try:
                    tree = ET.parse(fp)
                    root = tree.getroot()
                    for e in root.iter():
                        # Skip elements with no text or whitespace-only text.
                        if e.text and e.text.strip():
                            text = e.text.strip()
                            i = t(text, return_tensors="pt", truncation=True, padding=True)
                            with torch.no_grad():
                                # Mean-pool token embeddings into one vector per text.
                                emb = m(**i).last_hidden_state.mean(dim=1).numpy()
                            embeddings.append(emb)
                            ds.append(text)
                except Exception as e:
                    logging.error(f"Error processing {fp}: {str(e)}")
    if not embeddings:
        # Avoid np.vstack crashing on an empty list when no text was found.
        return np.empty((0, m.config.hidden_size)), ds
    return np.vstack(embeddings), ds
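
# ceas() returns (embeddings, ds): a (num_texts, hidden) NumPy array
# (384-dimensional for all-MiniLM-L6-v2) aligned index-for-index with the
# list of source strings.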
def qvs(query: str, embeddings, ds, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Return the top-5 stored texts most cosine-similar to the query."""
    t = AutoTokenizer.from_pretrained(model_name)
    m = AutoModel.from_pretrained(model_name)
    i = t(query, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        qe = m(**i).last_hidden_state.mean(dim=1).numpy()
    similarities = cosine_similarity(qe, embeddings)
    # Indices of the five highest similarities, best match first.
    top_k_indices = similarities[0].argsort()[-5:][::-1]
    return [ds[i] for i in top_k_indices]
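
# Usage sketch (assumes ceas() was run over the same folder so the query and
# corpus share an embedding space):
#
#   embeddings, ds = ceas('data')
#   for text in qvs('dropout regularisation', embeddings, ds):
#       print(text)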
def fetch_nzlii_data(query: str) -> List[Dict[str, Any]]:
    """Search the NZLII endpoint for the query and return normalised result dicts."""
    base_url = "https://nzlii.org/cgi-bin/sinosrch.cgi"
    params = {
        "method": "auto",
        "query": query,
        "meta": "/nz",
        "mask_path": "",
        "results": "50",
        "format": "json"
    }
    try:
        response = requests.get(base_url, params=params, headers={"Accept": "application/json"}, timeout=10)
        response.raise_for_status()
        results = response.json().get("results", [])
        processed_results = []
        for result in results:
            processed_results.append({
                "title": result.get("title", ""),
                "citation": result.get("citation", ""),
                "date": result.get("date", ""),
                "court": result.get("court", ""),
                "summary": result.get("summary", ""),
                "url": result.get("url", "")
            })
        return processed_results
    except requests.exceptions.RequestException as e:
        logging.error(f"Failed to fetch data from NZLII API: {str(e)}")
        return []
    except ValueError as e:
        logging.error(f"Failed to parse NZLII API response: {str(e)}")
        return []
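
# Usage sketch (network access, and JSON support on the NZLII endpoint, are
# assumed; on any failure the function returns an empty list):
#
#   for case in fetch_nzlii_data('negligence duty of care'):
#       print(case['title'], case['url'])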
def main():
    fp = 'data'
    m = cmf(fp)
    logging.info(f"Created dynamic PyTorch model with sections: {list(m.s.keys())}")
    # Derive the expected input width from the first Linear layer of the first section.
    fs = next(iter(m.s.keys()))
    fl = m.s[fs][0]
    ife = fl[0].in_features
    si = torch.randn(1, ife)
    # Eval mode for the smoke test: BatchNorm1d cannot normalise a batch of
    # one in training mode.
    m.eval()
    with torch.no_grad():
        so = m(si)
    logging.info(f"Sample output shape: {so.shape}")
    embeddings, ds = ceas(fp)
    a = Accelerator()
    opt = torch.optim.Adam(m.parameters(), lr=0.001)
    c = nn.CrossEntropyLoss()
    ne = 10
    # Synthetic dataset of random inputs and binary labels for the demo training loop.
    d = torch.utils.data.TensorDataset(torch.randn(100, ife), torch.randint(0, 2, (100,)))
    td = torch.utils.data.DataLoader(d, batch_size=16, shuffle=True)
    m, opt, td = a.prepare(m, opt, td)
    for e in range(ne):
        m.train()
        tl = 0
        for inputs, labels in td:
            opt.zero_grad()
            out = m(inputs)
            loss = c(out, labels)
            a.backward(loss)
            opt.step()
            tl += loss.item()
        al = tl / len(td)
        logging.info(f"Epoch {e+1}/{ne}, Average Loss: {al:.4f}")
    uq = "example query text"
    r = qvs(uq, embeddings, ds)
    logging.info(f"Query results: {r}")
    nz_data = fetch_nzlii_data(uq)
    logging.info(f"NZLII API results: {nz_data}")

if __name__ == "__main__":
    main()