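"""Build a dynamic PyTorch model from XML layer configurations, embed the XML
text content with a sentence-transformer for similarity search, and fetch
related results from the NZLII search endpoint."""
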
import logging
import os
import xml.etree.ElementTree as ET
from collections import defaultdict
from typing import Any, Dict, List, Optional

import numpy as np
import requests
import torch
import torch.nn as nn
from accelerate import Accelerator
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer

# Point the Hugging Face cache at a writable directory, creating it if needed
cache_dir = '/app/cache'
os.makedirs(cache_dir, exist_ok=True)
os.environ['TRANSFORMERS_CACHE'] = cache_dir
print(f"TRANSFORMERS_CACHE is set to: {os.environ['TRANSFORMERS_CACHE']}")

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class DM(nn.Module):
    """Dynamic model built from a mapping of section name -> list of layer configs."""

    def __init__(self, s: Dict[str, List[Dict[str, Any]]]):
        super().__init__()
        self.s = nn.ModuleDict()
        # Fall back to a single default section when no configuration is given
        if not s:
            s = {'default': [{'input_size': 128, 'output_size': 256, 'activation': 'relu', 'batch_norm': True, 'dropout': 0.1}]}
        for sn, layers in s.items():
            self.s[sn] = nn.ModuleList()
            for lp in layers:
                logging.info(f"Creating layer in section '{sn}' with params: {lp}")
                self.s[sn].append(self.cl(lp))

    def cl(self, lp: Dict[str, Any]) -> nn.Module:
        """Build one block: Linear plus optional norm, activation, dropout, and extras."""
        l = [nn.Linear(lp['input_size'], lp['output_size'])]
        if lp.get('batch_norm', True):
            l.append(nn.BatchNorm1d(lp['output_size']))
        a = lp.get('activation', 'relu')
        if a == 'relu':
            l.append(nn.ReLU(inplace=True))
        elif a == 'tanh':
            l.append(nn.Tanh())
        elif a == 'sigmoid':
            l.append(nn.Sigmoid())
        elif a == 'leaky_relu':
            l.append(nn.LeakyReLU(negative_slope=0.01, inplace=True))
        elif a == 'elu':
            l.append(nn.ELU(alpha=1.0, inplace=True))
        elif a not in (None, 'none'):  # 'none' is a valid value emitted by px()
            raise ValueError(f"Unsupported activation function: {a}")
        if dr := lp.get('dropout', 0.0):
            l.append(nn.Dropout(p=dr))
        for hlp in lp.get('hidden_layers', []):
            l.append(self.cl(hlp))
        if lp.get('memory_augmentation', True):
            l.append(MAL(lp['output_size']))
        if lp.get('hybrid_attention', True):
            l.append(HAL(lp['output_size']))
        if lp.get('dynamic_flash_attention', True):
            l.append(DFAL(lp['output_size']))
        return nn.Sequential(*l)

    def forward(self, x: torch.Tensor, sn: Optional[str] = None) -> torch.Tensor:
        if sn is not None:
            if sn not in self.s:
                raise KeyError(f"Section '{sn}' not found in model")
            for layer in self.s[sn]:
                x = layer(x)
        else:
            # Run every section in insertion order
            for layers in self.s.values():
                for layer in layers:
                    x = layer(x)
        return x

class MAL(nn.Module):
    """Memory-augmentation layer: adds a learned bias vector to its input."""

    def __init__(self, s: int):
        super().__init__()
        self.m = nn.Parameter(torch.randn(s))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.m

class HAL(nn.Module):
    """Hybrid-attention layer: multi-head self-attention over a length-1 sequence."""

    def __init__(self, s: int):
        super().__init__()
        # s must be divisible by num_heads; batch_first=True matches the
        # (batch, seq, embed) layout produced in forward()
        self.a = nn.MultiheadAttention(s, num_heads=8, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.unsqueeze(1)  # (batch, features) -> (batch, 1, features)
        ao, _ = self.a(x, x, x)
        return ao.squeeze(1)

class DFAL(nn.Module):
    """Dynamic flash-attention layer; structurally identical to HAL here."""

    def __init__(self, s: int):
        super().__init__()
        self.a = nn.MultiheadAttention(s, num_heads=8, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.unsqueeze(1)  # (batch, features) -> (batch, 1, features)
        ao, _ = self.a(x, x, x)
        return ao.squeeze(1)

def px(file_path: str) -> List[Dict[str, Any]]:
    """Parse an XML file into a list of layer-parameter dicts."""
    t = ET.parse(file_path)
    r = t.getroot()
    l = []
    for ly in r.findall('.//layer'):
        lp = {
            'input_size': int(ly.get('input_size', 128)),
            'output_size': int(ly.get('output_size', 256)),
            'activation': ly.get('activation', 'relu').lower(),
        }
        if lp['activation'] not in ['relu', 'tanh', 'sigmoid', 'leaky_relu', 'elu', 'none']:
            raise ValueError(f"Unsupported activation function: {lp['activation']}")
        if lp['input_size'] <= 0 or lp['output_size'] <= 0:
            raise ValueError("Layer dimensions must be positive integers")
        l.append(lp)
    if not l:
        l.append({'input_size': 128, 'output_size': 256, 'activation': 'relu'})
    return l
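
# A minimal sketch of the XML layout px() expects; the element and attribute
# names are inferred from the parser above, the root tag is arbitrary (the
# search is './/layer'), and the values are illustrative:
#
#   <model>
#       <layer input_size="128" output_size="256" activation="relu"/>
#       <layer input_size="256" output_size="64" activation="tanh"/>
#   </model>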

def cmf(folder_path: str) -> DM:
    """Create a DM model from all XML config files found under folder_path."""
    s = defaultdict(list)
    if not os.path.exists(folder_path):
        logging.warning(f"Folder {folder_path} does not exist. Creating model with default configuration.")
        return DM({})
    xf = False  # becomes True once at least one XML file is found
    for r, d, f in os.walk(folder_path):
        for file in f:
            if file.endswith('.xml'):
                xf = True
                fp = os.path.join(r, file)
                try:
                    l = px(fp)
                    sn = os.path.basename(r).replace('.', '_')
                    s[sn].extend(l)
                except Exception as e:
                    logging.error(f"Error processing {fp}: {str(e)}")
    if not xf:
        logging.warning("No XML files found. Creating model with default configuration.")
        return DM({})
    return DM(dict(s))
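
# A hypothetical on-disk layout cmf() would pick up; section names come from
# the immediate parent directory of each XML file:
#
#   data/
#       encoder/config.xml   -> section 'encoder'
#       decoder/config.xml   -> section 'decoder'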

def ceas(folder_path: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Compute mean-pooled embeddings for every non-empty text node in the XML files."""
    t = AutoTokenizer.from_pretrained(model_name)
    m = AutoModel.from_pretrained(model_name)
    m.eval()
    embeddings = []
    ds = []
    for r, d, f in os.walk(folder_path):
        for file in f:
            if file.endswith('.xml'):
                fp = os.path.join(r, file)
                try:
                    tree = ET.parse(fp)
                    root = tree.getroot()
                    for e in root.iter():
                        if e.text and e.text.strip():
                            text = e.text.strip()
                            i = t(text, return_tensors="pt", truncation=True, padding=True)
                            with torch.no_grad():
                                emb = m(**i).last_hidden_state.mean(dim=1).numpy()
                            embeddings.append(emb)
                            ds.append(text)
                except Exception as e:
                    logging.error(f"Error processing {fp}: {str(e)}")
    if not embeddings:
        # np.vstack raises on an empty list, so return an empty matrix instead
        return np.empty((0, m.config.hidden_size)), ds
    return np.vstack(embeddings), ds

def qvs(query: str, embeddings, ds, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Return the 5 stored texts most similar to the query by cosine similarity."""
    t = AutoTokenizer.from_pretrained(model_name)
    m = AutoModel.from_pretrained(model_name)
    i = t(query, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        qe = m(**i).last_hidden_state.mean(dim=1).numpy()
    similarities = cosine_similarity(qe, embeddings)
    top_k_indices = similarities[0].argsort()[-5:][::-1]  # highest similarity first
    return [ds[i] for i in top_k_indices]
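
# Illustrative usage (the query string here is a placeholder): build the
# corpus once with ceas(), then query it repeatedly with qvs():
#
#   embeddings, texts = ceas('data')
#   matches = qvs("employment dispute", embeddings, texts)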

def fetch_nzlii_data(query: str) -> List[Dict[str, Any]]:
    """Query the NZLII search endpoint and return a normalised list of results."""
    base_url = "https://nzlii.org/cgi-bin/sinosrch.cgi"
    params = {
        "method": "auto",
        "query": query,
        "meta": "/nz",
        "mask_path": "",
        "results": "50",
        "format": "json"
    }
    try:
        response = requests.get(base_url, params=params, headers={"Accept": "application/json"}, timeout=10)
        response.raise_for_status()
        results = response.json().get("results", [])
        processed_results = []
        for result in results:
            processed_results.append({
                "title": result.get("title", ""),
                "citation": result.get("citation", ""),
                "date": result.get("date", ""),
                "court": result.get("court", ""),
                "summary": result.get("summary", ""),
                "url": result.get("url", "")
            })
        return processed_results
    except requests.exceptions.RequestException as e:
        logging.error(f"Failed to fetch data from NZLII API: {str(e)}")
        return []
    except ValueError as e:
        logging.error(f"Failed to parse NZLII API response: {str(e)}")
        return []

def main():
    fp = 'data'
    m = cmf(fp)
    logging.info(f"Created dynamic PyTorch model with sections: {list(m.s.keys())}")

    # Infer the expected input width from the first Linear layer of the first section
    fs = next(iter(m.s.keys()))
    fl = m.s[fs][0]
    ife = fl[0].in_features
    si = torch.randn(1, ife)
    out = m(si)
    logging.info(f"Sample output shape: {out.shape}")

    embeddings, ds = ceas(fp)

    a = Accelerator()
    opt = torch.optim.Adam(m.parameters(), lr=0.001)
    c = nn.CrossEntropyLoss()
    ne = 10
    # Synthetic data for a smoke-test training loop; real training would use
    # an actual labelled dataset
    d = torch.utils.data.TensorDataset(torch.randn(100, ife), torch.randint(0, 2, (100,)))
    td = torch.utils.data.DataLoader(d, batch_size=16, shuffle=True)
    m, opt, td = a.prepare(m, opt, td)

    for e in range(ne):
        m.train()
        tl = 0.0
        for inputs, labels in td:
            opt.zero_grad()
            out = m(inputs)
            loss = c(out, labels)
            a.backward(loss)
            opt.step()
            tl += loss.item()
        al = tl / len(td)
        logging.info(f"Epoch {e+1}/{ne}, Average Loss: {al:.4f}")

    uq = "example query text"
    r = qvs(uq, embeddings, ds)
    logging.info(f"Query results: {r}")

    nz_data = fetch_nzlii_data(uq)
    logging.info(f"NZLII API results: {nz_data}")

if __name__ == "__main__":
    main()