# sqlbot / app.py
# barathm2001's picture
# Upload 3 files
# 517b429 verified
# Raw
# History Blame Contribute Delete
# 2.62 kB
import logging
import os
from fastapi import FastAPI, HTTPException
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, PeftConfig
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
# Configure logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application object that the route decorators below attach to.
app = FastAPI()

# Placeholders for the model, tokenizer, and pipeline; load_model() assigns
# them via ``global`` when the startup event fires.
model = tokenizer = pipe = None

# Read the Hugging Face token from the environment and fail fast when it is
# missing or empty.
hf_token = os.getenv("HUGGINGFACE_TOKEN")
if not hf_token:
    raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
@app.on_event("startup")
async def load_model():
    """Load the base model, PEFT adapter, tokenizer, and pipeline at startup.

    Assigns the module-level ``model``, ``tokenizer`` and ``pipe`` globals,
    which the ``/generate`` endpoint reads.

    Raises:
        ImportError: If a required module cannot be imported during loading
            (logged, then re-raised).
        Exception: Any other loading failure is logged and re-raised.
    """
    global model, tokenizer, pipe

    adapter_id = "frankmorales2020/Mistral-7B-text-to-sql-flash-attention-2-dataeval"
    base_model_id = "mistralai/Mistral-7B-Instruct-v0.3"

    try:
        logger.info("Loading PEFT configuration...")
        config = PeftConfig.from_pretrained(adapter_id, token=hf_token)
        # Surface a mismatch between the base model recorded in the adapter
        # config and the one loaded below, rather than ignoring the config.
        recorded_base = config.base_model_name_or_path
        if recorded_base and recorded_base != base_model_id:
            logger.warning(
                "Adapter config names base model %s but %s is being loaded",
                recorded_base,
                base_model_id,
            )

        logger.info("Loading base model...")
        base_model = AutoModelForCausalLM.from_pretrained(base_model_id, token=hf_token)

        logger.info("Loading PEFT model...")
        model = PeftModel.from_pretrained(base_model, adapter_id, token=hf_token)

        logger.info("Loading tokenizer...")
        # The transformers pipeline needs a transformers tokenizer; the
        # mistral_common MistralTokenizer is a different class and cannot be
        # passed to pipeline().
        tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=hf_token)

        logger.info("Creating pipeline...")
        # The model is loaded with AutoModelForCausalLM, so the matching task
        # is "text-generation" ("text2text-generation" is for seq2seq models).
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        logger.info("Model, tokenizer, and pipeline loaded successfully.")
    except ImportError as e:
        logger.error(f"Error importing required modules. Please check your installation: {e}")
        raise
    except Exception as e:
        logger.error(f"Error loading model or creating pipeline: {e}")
        raise
@app.get("/")
def home():
    """Return the fixed greeting payload for the root path."""
    greeting = {"message": "Hello World"}
    return greeting
@app.get("/generate")
async def generate(text: str):
    """Run the loaded pipeline on ``text`` and return the generated text.

    Args:
        text: Prompt passed to the pipeline as its first argument.

    Returns:
        A dict ``{"output": ...}`` holding the ``generated_text`` field of
        the first pipeline result.

    Raises:
        HTTPException: 503 when the pipeline has not been loaded yet;
            500 when the pipeline call (or reading its result) fails.
    """
    import asyncio
    from functools import partial

    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # The pipeline call is synchronous; run it in the default executor so
        # the event loop is not blocked while generation is in progress.
        run_pipe = partial(pipe, text, max_length=100, num_return_sequences=1)
        loop = asyncio.get_running_loop()
        output = await loop.run_in_executor(None, run_pipe)
        return {"output": output[0]['generated_text']}
    except Exception as e:
        logger.error(f"Error during text generation: {e}")
        raise HTTPException(status_code=500, detail=f"Error during text generation: {str(e)}") from e
if __name__ == "__main__":
    # When run as a script, serve the app with uvicorn on all interfaces,
    # port 7860.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)