import numpy as np
import torch

def extract_text_feature(prompt, model, processor, device="cpu"):
    """Extract text features with the OwlViT text encoder

    Args:
        prompt: a single text query
        model: OwlViT model
        processor: OwlViT processor
        device (str, optional): device to run on. Defaults to 'cpu';
            overridden with 'cuda' when a GPU is available.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.no_grad():
        # tokenize the prompt and move the token ids to the target device
        input_ids = torch.as_tensor(processor(text=prompt)["input_ids"]).to(device)
        text_outputs = model.owlvit.text_model(
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        )
        # pooled output, projected into the joint embedding space and L2-normalized
        text_embeds = text_outputs[1]
        text_embeds = model.owlvit.text_projection(text_embeds)
        text_embeds /= text_embeds.norm(p=2, dim=-1, keepdim=True) + 1e-6
        query_embeds = text_embeds
    return input_ids, query_embeds

def prompt2vec(prompt: str, model, processor):
    """Convert a text prompt into a query vector

    Args:
        prompt (str): text to be tokenized and encoded

    Returns:
        input_ids: token ids produced by the tokenizer
        xq: query vector representing the original prompt
    """
    # inputs = tokenizer(prompt, return_tensors='pt')
    # out = clip.get_text_features(**inputs)
    input_ids, xq = extract_text_feature(prompt, model, processor)
    input_ids = input_ids.detach().cpu().numpy()
    xq = xq.detach().cpu().numpy()
    return input_ids, xq

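# Usage sketch (illustrative, not part of the original pipeline): the checkpoint
# name below is an assumption; any OwlViT checkpoint with the same architecture works.
#
#   from transformers import OwlViTProcessor, OwlViTForObjectDetection
#   processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
#   model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
#   input_ids, xq = prompt2vec("a photo of a cat", model, processor)
#   # xq is a (num_queries, embed_dim) numpy array used as the search vector
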
def tune(clf, X, y, iters=2):
    """Train the zero-shot classifier on user feedback

    Args:
        clf (Classifier): classifier to be updated
        X (numpy.ndarray): input vectors (retrieved vectors)
        y (list of floats or numpy.ndarray): scores given by the user
        iters (int, optional): number of update iterations to run
    """
    assert len(X) == len(y)
    # train the classifier
    clf.fit(X, y, iters=iters)
    # extract the updated query vector(s)
    return clf.get_weights()

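# Feedback-loop sketch (names below are illustrative assumptions): `client` is the
# vector database client used by `Classifier`, "mydb.objects" a stand-in table
# name, `xq` the (num_class, dims) array from prompt2vec, and X / y the probed
# samples and the user's scores.
#
#   clf = Classifier(client, "mydb.objects", xq)
#   xq_new = tune(clf, X, y, iters=2)   # refined query vectors for the next search
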
class Classifier:
    """Multi-class zero-shot classifier

    This classifier acts as a proxy for the user's reaction to the probed images.
    The proxy replaces the original query vector generated from the prompt and
    eventually gives the user a more satisfying retrieval result.
    This is a common pattern in recommendation systems: the classifier recommends
    more precise results as it accumulates user activity.
    It is a multi-class classifier: for N queries it assigns the queries to the
    first N classes, while the last class collects the negative samples.
    """

    def __init__(self, client, obj_db: str, xq: np.ndarray):
        init_weight = torch.Tensor(xq)
        self.num_class = xq.shape[0]
        self.DIMS = xq.shape[1]
        # convert the initial query `xq` into a tensor to initialize the weights
        self.weight = init_weight
        self.client = client
        self.obj_db = obj_db

    def fit(self, X: list, y: list, iters: int = 5):
        # serialize the current weights (with an appended bias term of 1)
        # into one SQL array literal per class
        xq_s = [
            f"[{', '.join([str(float(fnum)) for fnum in _xq + [1]])}]"
            for _xq in self.get_weights().tolist()
        ]
        for _ in range(iters):
            # gradients accumulated per class
            grad = []
            # normalize the weights before inference;
            # this constrains the gradient, otherwise the query vector explodes
            self.weight.data /= torch.norm(self.weight.data, p=2, dim=-1, keepdim=True)
            for n in range(self.num_class):
                # select the training samples for class `n` (plus negatives) and build 0/1 labels
                labels, objs = list(
                    map(
                        list,
                        zip(
                            *[
                                [1 if y[i] == n else 0, x]
                                for i, x in enumerate(X)
                                if y[i] in [n, self.num_class + 1]
                            ]
                        ),
                    )
                )
                # NOTE from @fangruil
                # Use SQL to calculate the gradient.
                # For binary cross entropy we have
                #     g = (1 / (1 + exp(-XW)) - Y)^T X
                # To simplify the query, the calculation is split across the classes.
                grad_q_str = f"""
                    SELECT sumForEachArray(arrayMap((x,y,gt)->arrayMap(i->i*(y-gt), x), X, Y, GT)) AS grad
                    FROM (
                        SELECT groupArray(arrayPopBack(prelogit)) AS X,
                               groupArray(1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {xq_s[n]}))))) AS Y,
                               {labels} AS GT
                        FROM {self.obj_db} WHERE obj_id IN {objs})"""
                grad.append(torch.as_tensor(self.client.fetch(grad_q_str)[0]["grad"]))
            # gradient descent step on the weights
            grad = torch.stack(grad, dim=0)
            self.weight -= 0.1 * grad

    def get_weights(self):
        xq = self.weight.detach().numpy()
        return xq

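# Reference sketch (illustration only, not called anywhere): per class, the SQL in
# `fit` evaluates the binary cross entropy gradient g = (sigmoid(P @ [w, 1]) - gt)^T P[:, :-1],
# where P holds the stored `prelogit` rows, [w, 1] is the class weight vector with a
# fixed bias of 1, and gt the 0/1 labels. A local torch equivalent would be:
#
#   def bce_grad(prelogit: torch.Tensor, gt: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
#       # prelogit: (n, dims + 1), gt: (n,), w: (dims,)
#       w_b = torch.cat([w, torch.ones(1)])        # append the fixed bias weight
#       p = torch.sigmoid(prelogit @ w_b)          # predicted probability per sample
#       return (p - gt) @ prelogit[:, :-1]         # gradient w.r.t. w (bias column dropped)
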
class SplitLayer(torch.nn.Module):
    """Split the last dimension of a tensor into size-1 chunks."""

    def forward(self, x):
        return torch.split(x, 1, dim=-1)
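
# Example (illustrative shapes): SplitLayer turns a (..., C) tensor into a tuple
# of C tensors of shape (..., 1), one per output channel.
#
#   split = SplitLayer()
#   parts = split(torch.zeros(2, 3))   # tuple of 3 tensors, each of shape (2, 1)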