tanya17 commited on
Commit
fdf9ba3
·
verified ·
1 Parent(s): 988badf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -15,11 +15,15 @@ import random
15
  import requests
16
  import os
17
 
18
- from transformers import pipeline, set_seed
19
 
20
- # Load a text generation model locally
21
- generator = pipeline('text-generation', model='gpt2')
22
- set_seed(42)
 
 
 
 
23
 
24
  # File to store feedback
25
  FEEDBACK_FILE = "user_feedback.csv"
@@ -29,6 +33,10 @@ def huggingface_chatbot(user_input):
29
  result = generator(user_input, max_length=150, temperature=0.7, do_sample=True)
30
  if isinstance(result, list) and "generated_text" in result[0]:
31
  return result[0]["generated_text"]
 
 
 
 
32
  else:
33
  return "⚠️ Could not parse model response."
34
  except Exception as e:
@@ -36,6 +44,7 @@ def huggingface_chatbot(user_input):
36
 
37
 
38
 
 
39
  # Database setup for user authentication
40
  def init_db():
41
  conn = sqlite3.connect("users.db")
 
15
  import requests
16
  import os
17
 
18
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
19
 
20
+ # Load tokenizer and model for Flan-T5
21
+ model_name = "google/flan-t5-base"
22
+ tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
23
+ model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
24
+
25
+ # Create a pipeline
26
+ generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
27
 
28
  # File to store feedback
29
  FEEDBACK_FILE = "user_feedback.csv"
 
33
  result = generator(user_input, max_length=150, temperature=0.7, do_sample=True)
34
  if isinstance(result, list) and "generated_text" in result[0]:
35
  return result[0]["generated_text"]
36
+ elif "generated_text" in result:
37
+ return result["generated_text"]
38
+ elif "text" in result[0]:
39
+ return result[0]["text"]
40
  else:
41
  return "⚠️ Could not parse model response."
42
  except Exception as e:
 
44
 
45
 
46
 
47
+
48
  # Database setup for user authentication
49
  def init_db():
50
  conn = sqlite3.connect("users.db")