Ajay12345678980 committed on
Commit
b231aba
·
verified ·
1 Parent(s): a0c2a3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -16
app.py CHANGED
@@ -11,22 +11,36 @@ model = GPT2LMHeadModel.from_pretrained(model_repo_id)
11
  tokenizer = GPT2Tokenizer.from_pretrained(model_repo_id)
12
 
13
  # Define the prediction function
14
def predict(text):
    """Generate a sampled continuation of *text* with the loaded GPT-2 model.

    Returns the decoded, whitespace-stripped generation, or an error
    string if anything in the pipeline fails (keeps the UI responsive
    instead of crashing).
    """
    try:
        # Tokenize the prompt into model-ready tensor form.
        encoded = tokenizer.encode(text, return_tensors="pt")

        # Inference only — disable autograd bookkeeping.
        with torch.no_grad():
            generated = model.generate(encoded, max_length=50, do_sample=True)

        # Turn the first (and only) sequence back into plain text.
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
        return decoded.strip()
    except Exception as e:
        # Surface the failure to the caller as text for debugging.
        return f"An error occurred: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  # Gradio interface setup
32
  interface = gr.Interface(
 
11
  tokenizer = GPT2Tokenizer.from_pretrained(model_repo_id)
12
 
13
  # Define the prediction function
14
def generate_answer(question):
    """Generate an answer to *question* with the fine-tuned GPT-2 model.

    The model is presumed fine-tuned to emit text containing an "Answer"
    marker followed by the answer and an optional "<ANSWER_ENDED>"
    sentinel (TODO: confirm against the training format). The text
    between marker and sentinel is returned, stripped of whitespace.

    Returns the extracted answer string, or the full decoded output if
    the "Answer" marker is absent.
    """
    # Use whatever device the model actually lives on instead of
    # hard-coding "cuda" — the original crashed on CPU-only hosts and
    # mismatched a model that was never moved to the GPU.
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(question, return_tensors="pt").to(device)

    # Explicit attention mask and pad token id silence the HF generation
    # warnings; ones_like inherits input_ids' device.
    attention_mask = torch.ones_like(input_ids)
    pad_token_id = tokenizer.eos_token_id

    # Inference only — no gradients needed.
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=100,
            num_return_sequences=1,
            attention_mask=attention_mask,
            pad_token_id=pad_token_id,
        )

    decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)

    start_index = decoded_output.find("Answer")
    if start_index == -1:
        # Marker missing entirely — the original sliced from index 5
        # here (find() returned -1, then + len("Answer")), silently
        # corrupting the output. Fall back to the raw decoded text.
        return decoded_output.strip()

    answer_start = start_index + len("Answer")
    end_index = decoded_output.find("<ANSWER_ENDED>")
    if end_index != -1:
        # Extract the text between "Answer" and "<ANSWER_ENDED>".
        return decoded_output[answer_start:end_index].strip()
    # Sentinel absent — return everything after the "Answer" marker.
    return decoded_output[answer_start:].strip()
44
 
45
  # Gradio interface setup
46
  interface = gr.Interface(