Hopefully someone out there knows how to get this API configured and running with Gradio. No matter what I do, I cannot get the server to load my code base, and I cannot connect to it from Gradio. This is how I am currently trying to set up and launch the server:
%%writefile clean_server.py
import threading
import asyncio
import uvicorn
import nest_asyncio
from fastapi import FastAPI
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
from langserve import add_routes
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_core.runnables import RunnableLambda
from operator import itemgetter
from pydantic import BaseModel, model_validator
import pprint
# Allow asyncio to run multiple times in Jupyter notebook
nest_asyncio.apply()
# Load models
instruct_llm = ChatNVIDIA(model="mistralai/mixtral-8x22b-instruct-v0.1")
embedder = NVIDIAEmbeddings(model="nvidia/nv-embed-v1", truncate="END")
# Load FAISS vectorstore
vectorstore = FAISS.load_local(
"docstore_index",
embeddings=embedder,
index_name="index",
allow_dangerous_deserialization=True
)
retriever = vectorstore.as_retriever()
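# as_retriever() defaults to plain similarity search (k defaults to 4 in LangChain)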
# Test vectorstore
try:
    test_docs = retriever.invoke("test")  # get_relevant_documents() is deprecated in newer LangChain
    print(f"✅ Vectorstore loaded; test query returned {len(test_docs)} docs")
except Exception as e:
    print("❌ Vectorstore load failed:", str(e))
# Create FastAPI app
app = FastAPI(
title="LangChain Server",
version="1.0",
description="Final LangServe RAG server for evaluation"
)
from fastapi.middleware.cors import CORSMiddleware
# Add CORS middleware to allow communication with different origins (Gradio frontend)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all domains for testing; modify this for production
allow_credentials=True,
allow_methods=["*"], # Allow all HTTP methods (POST, GET, etc.)
allow_headers=["*"], # Allow all headers
)
# Input coercion helper
def clean_input(raw):
if isinstance(raw, str):
return raw
elif isinstance(raw, dict):
return raw.get("input") or raw.get("question") or str(raw)
return str(raw)
# Optional: Model for /retriever route
class RetrieverInput(BaseModel):
    input: str

    @model_validator(mode="before")
    @classmethod  # Pydantic v2 expects mode="before" validators to be classmethods
    def coerce_input(cls, values):
        if isinstance(values, str):
            return {"input": values}
        if isinstance(values, dict) and "input" in values:
            return values
        raise ValueError("Input must be a string or a dict with an 'input' field.")
# Wrap input for retriever route
def get_query_input(x):
    try:
        return RetrieverInput.model_validate(x).input  # parse_obj() is deprecated in Pydantic v2
    except Exception:
        return clean_input(x)
retriever_chain = RunnableLambda(lambda x: retriever.invoke(get_query_input(x)))
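# NOTE: retriever_chain is never mounted with add_routes; the /retriever POST route below handles requests directly.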
# Define the retriever route
@app.post("/retriever")
async def retriever_route(payload: dict):
input_query = clean_input(payload)
result = retriever.invoke(input_query)
return {"result": result}
# RAG Chain
rag_chain = RetrievalQA.from_chain_type(
llm=instruct_llm,
retriever=retriever,
return_source_documents=True
)
# Wrap RAG with clean input and logging
def debug_rag_input(x):
print("🔍 Incoming input to /generator route:")
print("🔎 Type:", type(x))
pprint.pprint(x)
question = clean_input(x)
print("🧪 Cleaned question:", question)
try:
result = rag_chain.invoke(question)
print("✅ RAG chain completed.")
return result
except Exception as e:
print("❌ RAG chain error:", str(e))
return {"error": str(e)}
generator_chain = RunnableLambda(debug_rag_input)
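# NOTE: generator_chain (with the debug logging) is never mounted with add_routes; the /generator POST route below calls rag_chain directly.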
# Define the generator route
@app.post("/generator")
async def generator_route(payload: dict):
question = clean_input(payload)
result = rag_chain.invoke(question)
return {"result": result}
# Health check route
@app.get("/health")
def health_check():
return {"status": "ok"}
# Echo route for debugging payloads
add_routes(app, RunnableLambda(lambda x: {"echo": x}), path="/echo")
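# add_routes exposes the standard LangServe endpoints for this runnable: /echo/invoke, /echo/batch, /echo/stream, /echo/playground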
# Start the server in a background thread
def start_server():
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)  # give this thread its own event loop
    print("🚀 Starting FastAPI server on port 8090...")
    uvicorn.run(app, host="0.0.0.0", port=8090, log_level="info", loop="asyncio")

# Run in a separate thread so it does not block Jupyter
server_thread = threading.Thread(target=start_server, daemon=True)
server_thread.start()

# When executed as a plain script (python clean_server.py), keep the main
# thread alive; otherwise the daemon thread (and with it the server) dies
# as soon as the script reaches its end.
if __name__ == "__main__":
    server_thread.join()
Even a simple health check fails; the server never comes up:
!curl http://localhost:8090/health
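For context, the Gradio side is only a thin wrapper that POSTs to the /generator route. A simplified sketch of what I am running (the real handler does a bit more; function and variable names here are just for illustration):

# Simplified sketch of my Gradio client (the real handler does a bit more).
import requests
import gradio as gr

SERVER_URL = "http://localhost:8090"  # same port the server binds to above

def ask(question: str) -> str:
    # {"input": ...} matches what clean_input() on the server side expects
    resp = requests.post(f"{SERVER_URL}/generator", json={"input": question})
    resp.raise_for_status()
    return str(resp.json().get("result"))

gr.Interface(fn=ask, inputs="text", outputs="text").launch()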
I have also killed whatever process was holding the port (commands below), confirmed it was gone, and tried to restart the server, but nothing gets it to load. When I run the same code directly in the notebook instead of writing it out to clean_server.py, I can connect without any issue.
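The port cleanup was along these lines, run from the notebook:

!lsof -t -i :8090                 # find the PID bound to the port
!kill -9 $(lsof -t -i :8090)      # force-kill it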
Any help would be deeply appreciated.