howard-hou commited on
Commit
8c28418
1 Parent(s): 2833bac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -39,16 +39,16 @@ def generate(
39
  ctx,
40
  image_features,
41
  token_count=200,
42
- temperature=0.2,
43
  top_p=0.3,
44
  presencePenalty = 0.1,
45
  countPenalty = 0.1,
46
  ):
47
  args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
48
- alpha_frequency = countPenalty,
49
- alpha_presence = presencePenalty,
50
- token_ban = [], # ban the generation of some tokens
51
- token_stop = [0, 261]) # stop generation whenever you see any token here
52
  ctx = ctx.strip()
53
  all_tokens = []
54
  out_last = 0
@@ -56,9 +56,11 @@ def generate(
56
  occurrence = {}
57
  for i in range(int(token_count)):
58
  if i == 0:
 
 
59
  input_ids = pipeline.encode(ctx)
60
  text_embs = model.w['emb.weight'][input_ids]
61
- input_embs = torch.cat((image_features, text_embs), dim=0)[-ctx_limit:]
62
  out, state = model.forward(embs=input_embs, state=None)
63
  else:
64
  input_ids = [token]
 
39
  ctx,
40
  image_features,
41
  token_count=200,
42
+ temperature=1.0,
43
  top_p=0.3,
44
  presencePenalty = 0.1,
45
  countPenalty = 0.1,
46
  ):
47
  args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
48
+ alpha_frequency = countPenalty,
49
+ alpha_presence = presencePenalty,
50
+ token_ban = [], # ban the generation of some tokens
51
+ token_stop = [0, 261]) # stop generation whenever you see any token here
52
  ctx = ctx.strip()
53
  all_tokens = []
54
  out_last = 0
 
56
  occurrence = {}
57
  for i in range(int(token_count)):
58
  if i == 0:
59
+ prefix_ids = pipeline.encode("User: ")
60
+ prefix_embs = model.w['emb.weight'][prefix_ids]
61
  input_ids = pipeline.encode(ctx)
62
  text_embs = model.w['emb.weight'][input_ids]
63
+ input_embs = torch.cat((prefix_embs, image_features, text_embs), dim=0)[-ctx_limit:]
64
  out, state = model.forward(embs=input_embs, state=None)
65
  else:
66
  input_ids = [token]