Princess3 committed on
Commit
ad52c1a
1 Parent(s): 3af37cb

Upload m3.py

Files changed (1)
  1. m3.py +176 -0
m3.py ADDED
@@ -0,0 +1,176 @@
+ import os, xml.etree.ElementTree as ET, torch, torch.nn as nn, torch.nn.functional as F, faiss, numpy as np
+ from typing import List, Dict, Any, Optional
+ from collections import defaultdict
+ from accelerate import Accelerator
+ from transformers import AutoTokenizer, AutoModel
+ from termcolor import colored
+
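+ # DM: dynamic model container; builds one nn.ModuleList of blocks per named
+ # section from the layer-parameter dicts produced by the XML parser below.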
+ class DM(nn.Module):
+     def __init__(self, s: Dict[str, List[Dict[str, Any]]]):
+         super(DM, self).__init__()
+         self.s = nn.ModuleDict()
+         if not s: s = {'default': [{'input_size': 128, 'output_size': 256, 'activation': 'relu', 'batch_norm': True, 'dropout': 0.1}]}
+         for sn, l in s.items():
+             self.s[sn] = nn.ModuleList()
+             for lp in l:
+                 print(colored(f"Creating layer in section '{sn}' with params: {lp}", 'cyan'))
+                 self.s[sn].append(self.cl(lp))
+
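+     # cl: assemble a single block -- Linear, then optional BatchNorm, activation,
+     # Dropout, recursively nested hidden layers, and the MAL/HAL/DFAL add-ons.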
+     def cl(self, lp: Dict[str, Any]) -> nn.Module:
+         l = [nn.Linear(lp['input_size'], lp['output_size'])]
+         if lp.get('batch_norm', True): l.append(nn.BatchNorm1d(lp['output_size']))
+         a = lp.get('activation', 'relu')
+         if a == 'relu': l.append(nn.ReLU(inplace=True))
+         elif a == 'tanh': l.append(nn.Tanh())
+         elif a == 'sigmoid': l.append(nn.Sigmoid())
+         elif a == 'leaky_relu': l.append(nn.LeakyReLU(negative_slope=0.01, inplace=True))
+         elif a == 'elu': l.append(nn.ELU(alpha=1.0, inplace=True))
+         elif a not in (None, 'none'): raise ValueError(f"Unsupported activation function: {a}")  # 'none' is accepted by px() as a no-op
+         if dr := lp.get('dropout', 0.0): l.append(nn.Dropout(p=dr))
+         if hl := lp.get('hidden_layers', []):
+             for hlp in hl: l.append(self.cl(hlp))
+         if lp.get('memory_augmentation', True): l.append(MAL(lp['output_size']))
+         if lp.get('hybrid_attention', True): l.append(HAL(lp['output_size']))
+         if lp.get('dynamic_flash_attention', True): l.append(DFAL(lp['output_size']))
+         return nn.Sequential(*l)
+
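+     # forward: run one named section, or every section in insertion order when sn is None.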
+     def forward(self, x: torch.Tensor, sn: Optional[str] = None) -> torch.Tensor:
+         if sn is not None:
+             if sn not in self.s: raise KeyError(f"Section '{sn}' not found in model")
+             for l in self.s[sn]: x = l(x)
+         else:
+             for layers in self.s.values():  # avoid shadowing the section name and the layer list
+                 for l in layers: x = l(x)
+         return x
+
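+ # MAL: memory-augmentation layer; adds a learned per-feature "memory" vector to its input.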
+ class MAL(nn.Module):
+     def __init__(self, s: int):
+         super(MAL, self).__init__()
+         self.m = nn.Parameter(torch.randn(s))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return x + self.m
+
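+ # HAL: hybrid-attention layer. Note nn.MultiheadAttention defaults to the
+ # seq-first layout, so after unsqueeze(1) the batch axis acts as the sequence
+ # and samples attend across the batch.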
+ class HAL(nn.Module):
+     def __init__(self, s: int):
+         super(HAL, self).__init__()
+         self.a = nn.MultiheadAttention(s, num_heads=8)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x.unsqueeze(1)
+         ao, _ = self.a(x, x, x)
+         return ao.squeeze(1)
+
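+ # DFAL: dynamic flash-attention layer; as committed it is identical to HAL
+ # (plain multi-head attention, no flash kernel).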
+ class DFAL(nn.Module):
+     def __init__(self, s: int):
+         super(DFAL, self).__init__()
+         self.a = nn.MultiheadAttention(s, num_heads=8)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x.unsqueeze(1)
+         ao, _ = self.a(x, x, x)
+         return ao.squeeze(1)
+
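+ # px: parse one XML file, turning each <layer> element's attributes into a layer-parameter dict.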
+ def px(file_path: str) -> List[Dict[str, Any]]:
+     t = ET.parse(file_path)
+     r = t.getroot()
+     l = []
+     for ly in r.findall('.//layer'):
+         lp = {'input_size': int(ly.get('input_size', 128)), 'output_size': int(ly.get('output_size', 256)), 'activation': ly.get('activation', 'relu').lower()}
+         if lp['activation'] not in ['relu', 'tanh', 'sigmoid', 'leaky_relu', 'elu', 'none']: raise ValueError(f"Unsupported activation function: {lp['activation']}")  # kept in sync with DM.cl
+         if lp['input_size'] <= 0 or lp['output_size'] <= 0: raise ValueError("Layer dimensions must be positive integers")
+         l.append(lp)
+     if not l: l.append({'input_size': 128, 'output_size': 256, 'activation': 'relu'})
+     return l
+
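+ # cmf: walk folder_path, parse every .xml file, and group the resulting layers
+ # into sections named after each file's parent directory.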
+ def cmf(folder_path: str) -> DM:
+     s = defaultdict(list)
+     if not os.path.exists(folder_path):
+         print(colored(f"Warning: Folder {folder_path} does not exist. Creating model with default configuration.", 'yellow'))
+         return DM({})
+     xf = False  # flips to True once the first XML file is seen; was True, which made the warning below unreachable
+     for r, d, f in os.walk(folder_path):
+         for file in f:
+             if file.endswith('.xml'):
+                 xf = True
+                 fp = os.path.join(r, file)
+                 try:
+                     l = px(fp)
+                     sn = os.path.basename(r).replace('.', '_')
+                     s[sn].extend(l)
+                 except Exception as e:
+                     print(colored(f"Error processing {fp}: {str(e)}", 'red'))
+     if not xf:
+         print(colored("Warning: No XML files found. Creating model with default configuration.", 'yellow'))
+         return DM({})
+     return DM(dict(s))
+
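+ # ceas: mean-pool sentence-transformer embeddings (384-dim for all-MiniLM-L6-v2)
+ # for every XML text node and store them in a FAISS L2 index alongside the raw texts.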
+ def ceas(folder_path: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
+     t = AutoTokenizer.from_pretrained(model_name)
+     m = AutoModel.from_pretrained(model_name)
+     vs = faiss.IndexFlatL2(384)
+     ds = []
+     for r, d, f in os.walk(folder_path):
+         for file in f:
+             if file.endswith('.xml'):
+                 fp = os.path.join(r, file)
+                 try:
+                     tree = ET.parse(fp)
+                     root = tree.getroot()
+                     for e in root.iter():
+                         if e.text:
+                             text = e.text.strip()
+                             i = t(text, return_tensors="pt", truncation=True, padding=True)
+                             with torch.no_grad():
+                                 emb = m(**i).last_hidden_state.mean(dim=1).numpy()
+                             vs.add(emb)
+                             ds.append(text)
+                 except Exception as e:
+                     print(colored(f"Error processing {fp}: {str(e)}", 'red'))
+     return vs, ds
+
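+ # qvs: embed the query the same way and return the stored texts nearest to it.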
+ def qvs(query: str, vs, ds, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
+     t = AutoTokenizer.from_pretrained(model_name)
+     m = AutoModel.from_pretrained(model_name)
+     i = t(query, return_tensors="pt", truncation=True, padding=True)
+     with torch.no_grad():
+         qe = m(**i).last_hidden_state.mean(dim=1).numpy()
+     D, I = vs.search(qe, k=5)
+     return [ds[j] for j in I[0] if j != -1]  # faiss pads with -1 when the index holds fewer than k entries
+
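+ # main: build the model from ./data, smoke-test a forward pass, train briefly on
+ # random data via Accelerate, then run a sample similarity query.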
+ def main():
+     fp = 'data'
+     m = cmf(fp)
+     print(colored(f"Created dynamic PyTorch model with sections: {list(m.s.keys())}", 'green'))
+     fs = next(iter(m.s.keys()))
+     fl = m.s[fs][0]
+     ife = fl[0].in_features
+     si = torch.randn(1, ife)
+     m.eval()  # BatchNorm1d rejects batches of size 1 in training mode
+     with torch.no_grad():
+         o = m(si)
+     print(colored(f"Sample output shape: {o.shape}", 'green'))
+     vs, ds = ceas(fp)
+     a = Accelerator()
+     opt = torch.optim.Adam(m.parameters(), lr=0.001)  # 'opt' (not 'o') so the forward output cannot clobber the optimizer
+     c = nn.CrossEntropyLoss()
+     ne = 10
+     d = torch.utils.data.TensorDataset(torch.randn(100, ife), torch.randint(0, 2, (100,)))
+     td = torch.utils.data.DataLoader(d, batch_size=16, shuffle=True)
+     m, opt, td = a.prepare(m, opt, td)
+     for e in range(ne):
+         m.train()
+         tl = 0
+         for bi, (xb, yb) in enumerate(td):  # renamed from (i, l), which the loss assignment below clobbered
+             opt.zero_grad()
+             out = m(xb)
+             loss = c(out, yb)
+             a.backward(loss)
+             opt.step()
+             tl += loss.item()
+         al = tl / len(td)
+         print(colored(f"Epoch {e+1}/{ne}, Average Loss: {al:.4f}", 'blue'))
+     uq = "example query text"
+     r = qvs(uq, vs, ds)
+     print(colored(f"Query results: {r}", 'magenta'))
+
+ if __name__ == "__main__":
+     main()
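
For reference, px accepts any XML document containing <layer> elements, reading the input_size, output_size, and activation attributes. A minimal config that would exercise the parser might look like the sketch below; the path data/default/model.xml and the attribute values are illustrative assumptions, not part of this commit:

<model>
    <layer input_size="128" output_size="256" activation="relu"/>
    <layer input_size="256" output_size="256" activation="tanh"/>
</model>

With this file placed under data/default/, cmf would expose the two layers as a section named "default".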