File size: 2,184 Bytes
8aba593
3e30009
0864a62
4a4b682
8aba593
3e30009
4a4b682
291341c
8aba593
4a4b682
0864a62
 
 
 
 
 
 
 
 
 
4a4b682
0864a62
4a4b682
 
 
 
 
 
8aba593
0864a62
 
4a4b682
0864a62
4a4b682
 
 
 
 
 
8aba593
0864a62
 
 
 
 
 
 
4a4b682
 
0864a62
4a4b682
0864a62
 
 
 
 
 
 
4a4b682
 
 
0864a62
4a4b682
0864a62
4a4b682
 
 
 
8aba593
0864a62
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr
from PIL import Image
from ultralytics import YOLO
from transformers import pipeline
import torch

# ——— 1) Chart-pattern detector (YOLOv8) ———
# Loaded once at import time from the local weights file "model.pt" and
# shared by every request handled by analyze_charts.
pattern_model = YOLO("model.pt")

# ——— 2) Trading-Hero-LLM pipeline ———
# Also built once at import time. device=0 selects the first CUDA GPU when
# one is available; device=-1 runs the pipeline on the CPU.
qa_pipeline = pipeline(
    task="text-generation",
    model="fuchenru/Trading-Hero-LLM",
    device=0 if torch.cuda.is_available() else -1,
    # Default generation setting: cap each answer at 100 new tokens.
    max_new_tokens=100,
)

def analyze_charts(img1: Image.Image, img2: Image.Image):
    output = []

    # Chart 1
    if img1:
        res1 = pattern_model(img1)[0]
        labels1 = [res1.names[int(b.cls[0])] for b in res1.boxes]
        output.append(
            "πŸ“Š Chart 1 Patterns:\n" +
            ("\n".join(f"β€’ {lbl}" for lbl in labels1) if labels1 else "No patterns detected.")
        )
    else:
        output.append("πŸ–ΌοΈ Chart 1: No image uploaded.")

    # Chart 2
    if img2:
        res2 = pattern_model(img2)[0]
        labels2 = [res2.names[int(b.cls[0])] for b in res2.boxes]
        output.append(
            "\nπŸ“Š Chart 2 Patterns:\n" +
            ("\n".join(f"β€’ {lbl}" for lbl in labels2) if labels2 else "No patterns detected.")
        )
    else:
        output.append("\nπŸ–ΌοΈ Chart 2: No image uploaded.")

    return "\n".join(output)

def answer_question(question: "str | None"):
    """Generate an answer to a free-text trading question with the LLM.

    Args:
        question: Text typed by the user. ``None`` or whitespace-only input
            is rejected without calling the model.

    Returns:
        The model's generated answer text, or a message asking the user to
        enter a question.
    """
    # The Gradio Textbox may deliver None (e.g. a cleared box), so check for
    # it before calling .strip().
    if question is None or not question.strip():
        return "❌ Please enter a question."
    # By default the text-generation pipeline prepends the prompt to
    # "generated_text"; return_full_text=False keeps only the model's answer.
    outputs = qa_pipeline(question, return_full_text=False)
    return outputs[0]["generated_text"]

# ——— Gradio UI ———
# Components render in the order they are created inside the Blocks context,
# so statement order below defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("## πŸ“ˆ Nifty AI Trading Assistant")

    # Two image slots side by side. type="pil" hands PIL images to
    # analyze_charts (presumably None for an empty slot — confirm against the
    # installed Gradio version).
    with gr.Row():
        img1 = gr.Image(type="pil", label="Upload Chart 1")
        img2 = gr.Image(type="pil", label="Upload Chart 2")

    # Section 1: clicking the button runs pattern detection on both charts
    # and writes the combined text report into pattern_out.
    analyze_btn = gr.Button("πŸ” Analyze Charts")
    pattern_out = gr.Textbox(label="Chart Pattern Output")
    analyze_btn.click(fn=analyze_charts, inputs=[img1, img2], outputs=pattern_out)

    gr.Markdown("---")

    # Section 2: free-text question answered by the LLM pipeline.
    question = gr.Textbox(label="πŸ’¬ Ask a Trading Question")
    answer_btn = gr.Button("πŸ€– Get LLM Response")
    llm_out = gr.Textbox(label="LLM Answer")
    answer_btn.click(fn=answer_question, inputs=question, outputs=llm_out)

# Starts the web server whenever this module is executed — including on import.
# NOTE(review): consider an `if __name__ == "__main__":` guard if this module
# is ever imported (e.g. by tests).
demo.launch()