| |
| import gradio as gr |
| import joblib |
| import numpy as np |
| from collections import Counter |
| from typing import List |
| import os |
|
|
| |
# The four canonical DNA nucleotide symbols.
BASES = ["A", "T", "C", "G"]
|
|
| def kmer_counts(seq: str, k=3): |
| seq = seq.strip().upper() |
| counts = Counter() |
| if len(seq) < k: |
| return counts |
| for i in range(len(seq) - k + 1): |
| counts[seq[i:i+k]] += 1 |
| return counts |
|
|
| def vectorize_single(seq: str, vocab: List[str], k=3): |
| X = np.zeros((1, len(vocab)), dtype=float) |
| c = kmer_counts(seq, k) |
| for j, kmer in enumerate(vocab): |
| X[0, j] = c.get(kmer, 0) |
| return X |
|
|
| |
| MODEL_PATH = "mutation_model.joblib" |
|
|
| if not os.path.exists(MODEL_PATH): |
| raise FileNotFoundError( |
| f"⚠️ Model file '{MODEL_PATH}' not found. " |
| "Please upload 'mutation_model.joblib' along with this app." |
| ) |
|
|
| model, vocab = joblib.load(MODEL_PATH) |
|
|
| |
| def predict_sequence(sequence: str): |
| if not sequence or len(sequence.strip()) < 3: |
| return {"error": "Please enter a valid DNA sequence (≥3 bases)."} |
|
|
| X = vectorize_single(sequence, vocab=vocab, k=3) |
| pred = model.predict(X)[0] |
| prob = float(model.predict_proba(X).max()) if hasattr(model, "predict_proba") else None |
|
|
| return { |
| "sequence": sequence, |
| "mutation_detected": bool(pred), |
| "confidence": round(prob, 3) if prob else "N/A" |
| } |
|
|
| |
| with gr.Blocks(theme=gr.themes.Soft()) as demo: |
| gr.Markdown( |
| """ |
| <h1 style="text-align:center;">🧬 DNA Mutation Analyzer</h1> |
| <p style="text-align:center;"> |
| Enter a DNA sequence to check for mutations using the ML model. |
| </p> |
| """ |
| ) |
| |
| with gr.Row(): |
| seq_input = gr.Textbox( |
| label="DNA Sequence", |
| placeholder="Enter sequence like ATGCGTACGTTAGC...", |
| lines=2, |
| ) |
| analyze_btn = gr.Button("🔍 Analyze Sequence") |
| result = gr.JSON(label="Analysis Result") |
| |
| analyze_btn.click(fn=predict_sequence, inputs=seq_input, outputs=result) |
|
|
| |
| def api_predict(payload: dict): |
| seq = payload.get("sequence", "") |
| return predict_sequence(seq) |
|
|
| if __name__ == "__main__": |
| |
| demo.launch(share=True, ssr_mode=False) |