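# m3.py (python)
# Uploaded by Princess3 ("Upload m3.py"), revision 2457015 (verified), 7.42 kB.
# NOTE: these header lines were scraped from a repository web page (its "raw",
# "history" and "blame" links); they are kept as comments so the file parses.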
import os, xml.etree.ElementTree as ET, torch, torch.nn as nn, torch.nn.functional as F, numpy as np
from typing import List, Dict, Any, Optional
from collections import defaultdict
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModel
from termcolor import colored
from sklearn.metrics.pairwise import cosine_similarity
class DM(nn.Module):
    """Dynamic model assembled from named sections of layer specifications.

    Args:
        s: mapping of section name -> list of layer-spec dicts (see ``cl``).
            A falsy mapping is replaced by a single ``'default'`` section
            holding one 128 -> 256 ReLU layer with batch norm and dropout 0.1.

    Each section becomes an ``nn.ModuleList`` stored in ``self.s`` (an
    ``nn.ModuleDict``), holding one ``nn.Sequential`` block per spec.
    """

    def __init__(self, s: Dict[str, List[Dict[str, Any]]]):
        super().__init__()
        self.s = nn.ModuleDict()
        if not s:
            s = {'default': [{'input_size': 128, 'output_size': 256, 'activation': 'relu', 'batch_norm': True, 'dropout': 0.1}]}
        for section_name, specs in s.items():
            section = nn.ModuleList()
            self.s[section_name] = section
            for spec in specs:
                print(colored(f"Creating layer in section '{section_name}' with params: {spec}", 'cyan'))
                section.append(self.cl(spec))

    def cl(self, lp: Dict[str, Any]) -> nn.Module:
        """Build one ``nn.Sequential`` block from a layer-spec dict.

        Keys read (defaults in brackets):
          ``input_size``, ``output_size`` (required): ``nn.Linear`` dims;
          ``batch_norm`` [True]: append ``nn.BatchNorm1d(output_size)``;
          ``activation`` ['relu']: one of ``relu``, ``tanh``, ``sigmoid``,
          ``leaky_relu``, ``elu``, or ``None`` / ``'none'`` for no activation;
          ``dropout`` [0.0]: append ``nn.Dropout`` when non-zero;
          ``hidden_layers`` [[]]: nested specs, each built recursively;
          ``memory_augmentation``, ``hybrid_attention``,
          ``dynamic_flash_attention`` [all True]: append MAL, HAL and DFAL
          sized to ``output_size``.

        Raises:
            ValueError: for an unsupported activation name.
        """
        out_size = lp['output_size']
        layers: List[nn.Module] = [nn.Linear(lp['input_size'], out_size)]
        if lp.get('batch_norm', True):
            layers.append(nn.BatchNorm1d(out_size))
        act = lp.get('activation', 'relu')
        # 'none' (the spelling allowed in XML layer specs) means "no
        # activation", exactly like None; it used to raise ValueError here.
        if act is not None and act != 'none':
            factories = {
                'relu': lambda: nn.ReLU(inplace=True),
                'tanh': nn.Tanh,
                'sigmoid': nn.Sigmoid,
                'leaky_relu': lambda: nn.LeakyReLU(negative_slope=0.01, inplace=True),
                'elu': lambda: nn.ELU(alpha=1.0, inplace=True),
            }
            factory = factories.get(act) if isinstance(act, str) else None
            if factory is None:
                raise ValueError(f"Unsupported activation function: {act}")
            layers.append(factory())
        if dr := lp.get('dropout', 0.0):
            layers.append(nn.Dropout(p=dr))
        for hlp in lp.get('hidden_layers') or []:
            layers.append(self.cl(hlp))
        if lp.get('memory_augmentation', True):
            layers.append(MAL(out_size))
        if lp.get('hybrid_attention', True):
            layers.append(HAL(out_size))
        if lp.get('dynamic_flash_attention', True):
            layers.append(DFAL(out_size))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, sn: Optional[str] = None) -> torch.Tensor:
        """Run *x* through section *sn*, or through every section in
        insertion order when *sn* is None.

        NOTE(review): chaining all sections assumes each section's output
        size matches the next section's input size; nothing checks this.

        Raises:
            KeyError: if *sn* names no section.
        """
        if sn is None:
            blocks = [blk for section in self.s.values() for blk in section]
        elif sn in self.s:
            blocks = list(self.s[sn])
        else:
            raise KeyError(f"Section '{sn}' not found in model")
        for blk in blocks:
            x = blk(x)
        return x
class MAL(nn.Module):
    """Add a learnable offset vector ``m`` (one entry per feature) to the input.

    Args:
        s: length of the offset vector; it broadcasts against the trailing
            dimension of the input.
    """

    def __init__(self, s: int):
        super().__init__()
        # Offset starts from a standard-normal random draw.
        self.m = nn.Parameter(torch.randn(s))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x + m`` with ``m`` broadcast over the leading dimensions."""
        return torch.add(x, self.m)
class HAL(nn.Module):
    """Multi-head self-attention applied to each sample as a length-1 sequence.

    Args:
        s: embedding size, i.e. the feature dimension of the input.
        num_heads: number of attention heads (default 8). ``s`` must be
            divisible by it, as required by ``nn.MultiheadAttention``.
    """

    def __init__(self, s: int, num_heads: int = 8):
        super().__init__()
        # batch_first=True makes the (batch, 1, s) view built in forward() be
        # read as batch x seq x embed. With the default (False) the batch axis
        # is taken as the sequence axis, and samples in a batch attend to each
        # other. The parameters are unchanged, so existing state_dicts still load.
        self.a = nn.MultiheadAttention(s, num_heads=num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` and return a tensor of the same shape.

        Expects ``x`` of shape ``(batch, s)``, as produced by the preceding
        Linear/BatchNorm1d layers in DM; TODO confirm for other callers.
        """
        seq = x.unsqueeze(1)  # (batch, 1, s)
        ao, _ = self.a(seq, seq, seq, need_weights=False)
        return ao.squeeze(1)
class DFAL(nn.Module):
    """Multi-head self-attention applied to each sample as a length-1 sequence.

    NOTE(review): despite the name, this class only wraps a plain
    ``nn.MultiheadAttention`` call; no dynamic or flash-attention specific
    logic is implemented here.

    Args:
        s: embedding size, i.e. the feature dimension of the input.
        num_heads: number of attention heads (default 8). ``s`` must be
            divisible by it, as required by ``nn.MultiheadAttention``.
    """

    def __init__(self, s: int, num_heads: int = 8):
        super().__init__()
        # batch_first=True makes the (batch, 1, s) view built in forward() be
        # read as batch x seq x embed. With the default (False) the batch axis
        # becomes the sequence axis, and samples in a batch attend to each other.
        self.a = nn.MultiheadAttention(s, num_heads=num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` and return a tensor of the same shape.

        Expects ``x`` of shape ``(batch, s)``, as produced by the preceding
        layers in DM; TODO confirm for other callers.
        """
        seq = x.unsqueeze(1)  # (batch, 1, s)
        ao, _ = self.a(seq, seq, seq, need_weights=False)
        return ao.squeeze(1)
def px(file_path: str) -> List[Dict[str, Any]]:
    """Parse layer specifications from an XML file.

    Every ``<layer>`` element below the document root yields one dict with
    ``input_size`` / ``output_size`` (attribute defaults 128 / 256) and
    ``activation`` (default ``'relu'``; whitespace-stripped and lower-cased;
    ``'none'`` is returned as ``None``, meaning no activation).

    Returns:
        The list of layer dicts, or a single default 128 -> 256 ReLU spec
        when the file contains no ``<layer>`` element.

    Raises:
        ValueError: unsupported activation, or non-integer / non-positive size.
        xml.etree.ElementTree.ParseError, OSError: malformed or unreadable file.
    """
    # Must match the activations DM.cl can build; 'none' means no activation.
    supported = ('relu', 'tanh', 'sigmoid', 'leaky_relu', 'elu', 'none')
    root = ET.parse(file_path).getroot()
    layers: List[Dict[str, Any]] = []
    for ly in root.findall('.//layer'):
        in_size = int(ly.get('input_size', 128))
        out_size = int(ly.get('output_size', 256))
        activation = ly.get('activation', 'relu').strip().lower()
        if activation not in supported:
            raise ValueError(f"Unsupported activation function: {activation}")
        if in_size <= 0 or out_size <= 0:
            raise ValueError("Layer dimensions must be positive integers")
        # DM.cl skips the activation when it is None; the string 'none' is
        # not a layer it can build.
        layers.append({
            'input_size': in_size,
            'output_size': out_size,
            'activation': None if activation == 'none' else activation,
        })
    if not layers:
        layers.append({'input_size': 128, 'output_size': 256, 'activation': 'relu'})
    return layers
def cmf(folder_path: str) -> DM:
    """Create a DM from every ``*.xml`` file found under *folder_path*.

    Layers from all XML files in one directory are grouped into a section
    named after that directory. ``.`` is replaced by ``_`` because
    ``nn.ModuleDict`` keys may not contain dots. Files that fail to parse are
    reported and skipped.

    A default-configured ``DM({})`` is returned when the folder does not
    exist or contains no XML files. When every XML file fails, the sections
    are empty and ``DM`` also falls back to its default configuration.
    """
    if not os.path.exists(folder_path):
        print(colored(f"Warning: Folder {folder_path} does not exist. Creating model with default configuration.", 'yellow'))
        return DM({})
    sections: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    # Must start False; starting True made the "no XML files" branch unreachable.
    found_xml = False
    for root, _dirs, files in os.walk(folder_path):
        # normpath drops a trailing separator ("data/" -> "data"), so the
        # section name is never '' (torch rejects empty module names).
        section_name = os.path.basename(os.path.normpath(root)).replace('.', '_')
        for name in files:
            if not name.endswith('.xml'):
                continue
            found_xml = True
            fp = os.path.join(root, name)
            try:
                layers = px(fp)
            except (ET.ParseError, ValueError, OSError) as e:
                # Best-effort: report the bad file and keep going.
                print(colored(f"Error processing {fp}: {str(e)}", 'red'))
                continue
            sections[section_name].extend(layers)
    if not found_xml:
        print(colored("Warning: No XML files found. Creating model with default configuration.", 'yellow'))
        return DM({})
    return DM(dict(sections))
def ceas(folder_path: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Embed the text of every XML element found under *folder_path*.

    Each element whose text is non-blank after stripping is tokenized
    individually and embedded with *model_name*. The embedding is the mean of
    the model's last hidden state over the token axis. Files that fail are
    reported and skipped.

    Returns:
        ``(embeddings, ds)``: a NumPy array with one row per embedded text,
        and the list of those stripped texts in the same order. When no
        text is found (or the folder is missing) the array has zero rows.
    """
    t = AutoTokenizer.from_pretrained(model_name)
    m = AutoModel.from_pretrained(model_name)
    embeddings = []
    ds = []
    for r, _d, f in os.walk(folder_path):
        for file in f:
            if not file.endswith('.xml'):
                continue
            fp = os.path.join(r, file)
            try:
                root = ET.parse(fp).getroot()
                for el in root.iter():
                    text = (el.text or '').strip()
                    # Skip whitespace-only text (indentation between tags),
                    # which would otherwise be embedded as an empty string.
                    if not text:
                        continue
                    i = t(text, return_tensors="pt", truncation=True, padding=True)
                    with torch.no_grad():
                        emb = m(**i).last_hidden_state.mean(dim=1).numpy()
                    embeddings.append(emb)
                    ds.append(text)
            except Exception as e:
                print(colored(f"Error processing {fp}: {str(e)}", 'red'))
    if not embeddings:
        # np.vstack([]) raises; return an empty matrix instead.
        return np.empty((0, getattr(m.config, 'hidden_size', 0)), dtype=np.float32), ds
    return np.vstack(embeddings), ds
def qvs(query: str, embeddings, ds, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", top_k: int = 5):
    """Return the *top_k* texts of *ds* most cosine-similar to *query*.

    The query is embedded the same way as in ``ceas``: mean of the last
    hidden state. The tokenizer and model are loaded on every call.

    Args:
        query: text to search for.
        embeddings: matrix with one row per entry of *ds*.
        ds: texts index-aligned with the rows of *embeddings*.
        model_name: Hugging Face model used to embed the query.
        top_k: maximum number of results (default 5, the previous fixed value).

    Returns:
        Up to *top_k* texts, best match first. Returns an empty list when
        there is nothing to search or ``top_k <= 0``.
    """
    # Guard before loading the model. cosine_similarity rejects zero rows,
    # and argsort()[-0:] would return every index.
    if top_k <= 0 or len(ds) == 0 or len(embeddings) == 0:
        return []
    t = AutoTokenizer.from_pretrained(model_name)
    m = AutoModel.from_pretrained(model_name)
    inputs = t(query, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        qe = m(**inputs).last_hidden_state.mean(dim=1).numpy()
    scores = cosine_similarity(qe, embeddings)[0]
    best = scores.argsort()[-top_k:][::-1]
    return [ds[idx] for idx in best]
def main():
    """Demo pipeline over the ``data`` folder.

    Builds the model from XML, runs a smoke-test forward pass, embeds the XML
    texts, trains for a few epochs on random data and runs a sample query.
    """
    fp = 'data'
    model = cmf(fp)
    print(colored(f"Created dynamic PyTorch model with sections: {list(model.s.keys())}", 'green'))
    first_section = next(iter(model.s.keys()))
    in_features = model.s[first_section][0][0].in_features
    # BatchNorm1d cannot compute batch statistics from a single sample in
    # training mode, so the one-sample smoke test runs in eval mode without
    # autograd.
    model.eval()
    with torch.no_grad():
        sample_out = model(torch.randn(1, in_features))
    print(colored(f"Sample output shape: {sample_out.shape}", 'green'))
    embeddings, ds = ceas(fp)
    accelerator = Accelerator()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    num_epochs = 10
    # NOTE(review): random inputs sized to the first section's input; with
    # several sections, forward() chains them all, which assumes compatible sizes.
    dataset = torch.utils.data.TensorDataset(torch.randn(100, in_features), torch.randint(0, 2, (100,)))
    loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for inputs, labels in loader:
            optimizer.zero_grad()
            # The model output gets its own name so the optimizer is not
            # overwritten before .step() (a Tensor has no .step()).
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            accelerator.backward(loss)
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(loader)
        print(colored(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}", 'blue'))
    uq = "example query text"
    results = qvs(uq, embeddings, ds)
    print(colored(f"Query results: {results}", 'magenta'))
if __name__ == '__main__':
    # Run the demo pipeline only when executed as a script, not on import.
    main()