SabziAi / app.py
from transformers import AutoModelForCausalLM, AutoTokenizer
from sympy import solve, symbols
from fastapi import FastAPI
import uvicorn
# Different models for each mode
MODEL_GENERAL = "meta-llama/Llama-2-7b-chat-hf"
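# Note: meta-llama/Llama-2-7b-chat-hf is a gated repo on the Hugging Face Hub;
# loading it requires accepting the license and an auth token, plus enough
# memory for a 7B-parameter model (well over 10 GB in full precision).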
MODEL_IRAN = "HooshvareLab/bert-fa-base-uncased"
MODEL_MATH = None  # math mode is handled by SymPy, not a transformer model
# Load the models
tokenizer_general = AutoTokenizer.from_pretrained(MODEL_GENERAL)
model_general = AutoModelForCausalLM.from_pretrained(MODEL_GENERAL)
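# Caveat: bert-fa-base-uncased is an encoder-only BERT checkpoint, so
# AutoModelForCausalLM wraps it as a BertLMHeadModel and transformers warns
# that is_decoder=True is expected; generation quality will be poor without
# a causal (decoder) model fine-tuned for Persian.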
tokenizer_iran = AutoTokenizer.from_pretrained(MODEL_IRAN)
model_iran = AutoModelForCausalLM.from_pretrained(MODEL_IRAN)
# FastAPI app to handle incoming requests
app = FastAPI()
def generate_response(model, tokenizer, prompt, max_tokens=100):
    inputs = tokenizer(prompt, return_tensors="pt")
    # Pass the full tokenizer output so the attention mask reaches generate()
    outputs = model.generate(**inputs, max_new_tokens=max_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
@app.post("/chat")
def chat(input_text: str, mode: str = "general"):
    if mode == "general":
        response = generate_response(model_general, tokenizer_general, input_text)
    elif mode == "iran":
        response = generate_response(model_iran, tokenizer_iran, input_text)
    elif mode == "math":
        x = symbols("x")
        try:
            # solve() sympifies the string: an expression like "x**2 - 4" is
            # treated as equal to zero; equations written with "=" need sympy.Eq.
            solution = solve(input_text, x)
            response = f"Solution: {solution}"
        except Exception as e:
            response = f"Math error: {str(e)}"
    else:
        response = "Invalid mode selected."
    return {"response": response}
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
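# Example request (a sketch, assuming the server is running locally on port 8000;
# input_text and mode are sent as query parameters since they are plain str args):
#   import requests
#   r = requests.post("http://localhost:8000/chat",
#                     params={"input_text": "x**2 - 4", "mode": "math"})
#   print(r.json())  # expected: {"response": "Solution: [-2, 2]"}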