ModelDeployment

Source code in madewithml/serve.py
@serve.deployment(route_prefix="/", num_replicas=1)  # num_replicas is an int in Ray Serve
@serve.ingress(app)
class ModelDeployment:
    def __init__(self, run_id: str, threshold: float = 0.9):
        """Initialize the model deployment.

        Args:
            run_id (str): MLflow run ID of the trained model to serve.
            threshold (float): minimum class probability required to keep a
                prediction; lower-confidence predictions are relabeled "other".
        """
        self.run_id = run_id
        self.threshold = threshold
        mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)  # so workers have access to model registry
        best_checkpoint = predict.get_best_checkpoint(run_id=run_id)
        self.predictor = TorchPredictor.from_checkpoint(best_checkpoint)
        self.preprocessor = self.predictor.get_preprocessor()

    @app.get("/")
    def _index(self) -> Dict:
        """Health check."""
        response = {
            "message": HTTPStatus.OK.phrase,
            "status-code": HTTPStatus.OK,
            "data": {},
        }
        return response

    @app.get("/run_id/")
    def _run_id(self) -> Dict:
        """Get the run ID of the model being served."""
        return {"run_id": self.run_id}

    @app.post("/evaluate/")
    async def _evaluate(self, request: Request) -> Dict:
        """Evaluate the served model on the dataset located at `dataset_loc` in the request body."""
        data = await request.json()
        results = evaluate.evaluate(run_id=self.run_id, dataset_loc=data.get("dataset_loc"))
        return {"results": results}

    @app.post("/predict/")
    async def _predict(self, request: Request) -> Dict:
        """Predict the tag for a project's title/description, applying the confidence threshold."""
        # Get prediction
        data = await request.json()
        df = pd.DataFrame([{"title": data.get("title", ""), "description": data.get("description", ""), "tag": ""}])
        results = predict.predict_with_proba(df=df, predictor=self.predictor)

        # Apply custom logic: fall back to "other" when the top-class
        # probability is below the configured threshold
        for i, result in enumerate(results):
            pred = result["prediction"]
            prob = result["probabilities"]
            if prob[pred] < self.threshold:
                results[i]["prediction"] = "other"

        return {"results": results}

__init__(run_id, threshold=0.9)

Initialize the model.

madewithml/serve.py
def __init__(self, run_id: str, threshold: float = 0.9):
    """Initialize the model.

    Args:
        run_id (str): MLflow run ID of the trained model to serve.
        threshold (float): minimum class probability required to keep a
            prediction; lower-confidence predictions are relabeled "other".
    """
    self.run_id = run_id
    self.threshold = threshold
    mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)  # so workers have access to model registry
    best_checkpoint = predict.get_best_checkpoint(run_id=run_id)
    self.predictor = TorchPredictor.from_checkpoint(best_checkpoint)
    self.preprocessor = self.predictor.get_preprocessor()